
Conversation

@Koratahiu
Contributor

An experimental, sketch-stage feature for OFT

Example:

My OFT:

OFT Block Size automatically adjusted for 17 layers. Changes:
1 layer from 128 to 32
12 layers from 128 to 120
2 layers from 128 to 135
2 layers from 128 to 160

The Issue

  • Block size 160: each output sums interactions over 160 parameters, which results in a large variance (a strong rotation).
  • Block size 32: each output sums interactions over only 32 parameters, which results in a small variance (a weak rotation).

If we leave it alone, the 160-block layers will dominate the training, and the 32-block layers will effectively be frozen because their gradients will be tiny in comparison.
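
A quick way to see the gap (a minimal sketch, not the actual OFT code: it assumes iid Gaussian off-diagonal entries in Q and looks only at the first-order perturbation Qx):

import torch

torch.manual_seed(0)

def first_order_variance(block_size, sigma=0.01, samples=4096):
    # Random skew-symmetric Q whose upper-triangular entries are N(0, sigma^2).
    upper = torch.triu(torch.randn(block_size, block_size) * sigma, diagonal=1)
    Q = upper - upper.t()
    x = torch.randn(samples, block_size)
    # First-order OFT perturbation of the activations: delta = x @ Q^T
    return (x @ Q.t()).var().item()

v_small = first_order_variance(32)
v_large = first_order_variance(160)
print(f"32-block variance:  {v_small:.6f}")
print(f"160-block variance: {v_large:.6f}")
print(f"ratio: {v_large / v_small:.2f}   # roughly 160/32 = 5")

The per-coordinate variance of Qx grows with the number of summed terms, i.e. with the block size, so at the same parameter scale the 160-block layers perturb their activations roughly 5x more strongly (in variance) than the 32-block layer.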

How to Align (The Solution)

To make the 32-block layer behave like the 160-block layer, we need to boost the 32-block layer.

We want the "Energy" (variance) of the small block to match that of the large block.

$$ \text{Scale}^2 \times \text{Size}_{small} = \text{Size}_{large} $$

$$ \text{Scale} = \sqrt{\frac{\text{Size}_{large}}{\text{Size}_{small}}} $$

  • The Scale enters squared because of the OFT expansion:

$$ R \approx I + 2Q + 2Q^2 + \dots $$
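
As a quick sanity check of that series (a sketch; it uses the $(I - Q)^{-1}(I + Q)$ form of the Cayley transform, whose expansion has exactly these signs):

import torch

torch.manual_seed(0)
n = 8
A = torch.randn(n, n) * 0.05
Q = A - A.t()                                # small skew-symmetric Q
I = torch.eye(n)

R_exact = torch.linalg.solve(I - Q, I + Q)   # Cayley: (I - Q)^-1 (I + Q)
R_series = I + 2 * Q + 2 * Q @ Q             # truncated series I + 2Q + 2Q^2

# The residual is O(|Q|^3), so it is tiny for small Q.
print((R_exact - R_series).abs().max().item())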

The Calculation

  • Target (Max): 160
  • Current: 32

$$ \text{Scale} = \sqrt{\frac{160}{32}} = \sqrt{5} \approx 2.23 $$

By multiplying the weights of the 32-block layer by 2.23, we ensure it rotates the inputs with the same intensity as the 160-block layer.
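
Applied to the block sizes from the log at the top of this post, the per-layer multipliers under this sqrt rule would be (a small illustrative sketch, not the actual implementation):

import math

# Block sizes from the example log above; 160 is the largest, so it is the target.
block_sizes = [32, 120, 128, 135, 160]
target = max(block_sizes)

for size in block_sizes:
    scale = math.sqrt(target / size)
    print(f"block {size:>3}: scale = {scale:.3f}")

This prints roughly 2.236, 1.155, 1.118, 1.089, and 1.000, so the 120-135 block layers only get a mild boost while the 32-block layer gets the full sqrt(5).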

TODO

  • To be tested
  • Is this correct?

@Koratahiu
Contributor Author

Update 1:

The scale is only ^2 during initialization.
After that, it shows a linear relationship (just like LoRA).
So I changed it to linear scaling.
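
Under the linear rule, the correction from the first comment presumably becomes

$$ \text{Scale} = \frac{\text{Size}_{large}}{\text{Size}_{small}} $$

e.g. 160/32 = 5 for the example above, instead of $\sqrt{5} \approx 2.23$.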

A small repro to prove this issue + the linear relationship:

import torch
import torch.nn as nn
import torch.optim as optim

def simple_cayley(Q):
    """Simple Cayley transform: (I+Q)^-1 (I-Q)"""
    I = torch.eye(Q.shape[-1], device=Q.device)
    return torch.linalg.solve(I + Q, I - Q)

def run_oft_sim(block_size, dim=64, steps=100, lr=1e-5):
    torch.manual_seed(42)
    
    X = torch.randn(8, dim)
    W_orig = torch.randn(dim, dim)
    Y_target = torch.randn(8, dim)
    num_blocks = dim // block_size
    num_params = block_size * (block_size - 1) // 2
    params = nn.Parameter(torch.zeros(num_blocks, num_params))
    triu_indices = torch.triu_indices(block_size, block_size, 1)

    opt = optim.SGD([params], lr=lr)

    for _ in range(steps):
        opt.zero_grad()
        
        # Construct Skew-Symmetric Q for each block
        Q_blocks = torch.zeros(num_blocks, block_size, block_size)
        Q_blocks[:, triu_indices[0], triu_indices[1]] = params
        Q_blocks = Q_blocks - Q_blocks.transpose(1, 2)
        
        # Apply Cayley to get R for each block
        R_list = [simple_cayley(Q_blocks[i]) for i in range(num_blocks)]
        
        # Build full Block Diagonal R
        R = torch.block_diag(*R_list)
        
        # Rotate Weights (W_new = R @ W)
        W_new = R @ W_orig
        
        # Forward & Loss
        Y_pred = X @ W_new.t()
        loss = (Y_pred - Y_target).pow(2).mean()
        
        loss.backward()
        opt.step()

    with torch.no_grad():
        Q_blocks = torch.zeros(num_blocks, block_size, block_size)
        Q_blocks[:, triu_indices[0], triu_indices[1]] = params
        Q_blocks = Q_blocks - Q_blocks.transpose(1, 2)
        norm = torch.norm(Q_blocks, p='fro', dim=(1,2)).mean().item()
        
    return norm

if __name__ == "__main__":
    print("Running OFT Gradient Dynamics Test...")
    print("-" * 40)
    
    norm_16 = run_oft_sim(block_size=16)
    norm_32 = run_oft_sim(block_size=32)
    
    ratio = norm_32 / norm_16
    sqrt_prediction = 32**0.5 / 16**0.5
    linear_prediction = 32 / 16
    
    print(f"Norm (Block=16): {norm_16:.5f}")
    print(f"Norm (Block=32): {norm_32:.5f}")
    print("-" * 40)
    print(f"Ratio (32/16):   {ratio:.3f}")
    print(f"Sqrt Prediction: {sqrt_prediction:.3f}")
    print(f"Linear Pred.:    {linear_prediction:.3f}")
    print("-" * 40)
    
    if abs(ratio - 2.0) < abs(ratio - 1.414):
        print("evidence supports LINEAR scaling.")
    else:
        print("evidence supports SQRT scaling.")

The log shows:

Running OFT Gradient Dynamics Test...
----------------------------------------
Norm (Block=16): 0.00437
Norm (Block=32): 0.00874
----------------------------------------
Ratio (32/16):   1.999
Sqrt Prediction: 1.414
Linear Pred.:    2.000
----------------------------------------
evidence supports LINEAR scaling.
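
One hypothetical follow-up check, reusing run_oft_sim from the repro above: compensate the learning rate linearly with the block size and confirm the resulting Q norms line up. This is only a sketch; it shrinks the larger block's lr rather than boosting the smaller block's parameters, but it tests the same exponent, and the actual feature may apply the scale to the parameters or the forward pass instead.

# Hypothetical check: scale lr inversely with block size (linear compensation).
ref = 16
norm_16 = run_oft_sim(block_size=16, lr=1e-5 * ref / 16)
norm_32 = run_oft_sim(block_size=32, lr=1e-5 * ref / 32)

# If the relationship is linear, the compensated ratio should come out near 1.0.
print(f"Compensated ratio (32/16): {norm_32 / norm_16:.3f}")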

@Koratahiu changed the title from "[OFT] Sqrt scaling for constant learning" to "[OFT] Linear scaling for constant learning" on Dec 30, 2025