candle-nn/src/layer_norm.rs (17 changes: 13 additions, 4 deletions)
@@ -117,16 +117,25 @@ impl Module for LayerNorm {
             DType::F16 | DType::BF16 => DType::F32,
             d => d,
         };
-        let hidden_size = x.dim(D::Minus1)?;
         let x = x.to_dtype(internal_dtype)?;
         let x = if self.remove_mean {
-            let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+            // Use mean_keepdim instead of manual division to preserve gradient flow
+            let mean_x = x.mean_keepdim(D::Minus1)?;
             x.broadcast_sub(&mean_x)?
         } else {
             x
         };
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        // Use mean_keepdim for variance calculation to preserve gradient flow
+        let var_x = x.sqr()?.mean_keepdim(D::Minus1)?;
+
+        // Create epsilon as a tensor for proper gradient flow
+        let eps_tensor = Tensor::new(&[self.eps], x.device())?.to_dtype(internal_dtype)?;
+
+        // Use broadcast_add for adding epsilon to maintain gradient tracking
+        let std_x = var_x.broadcast_add(&eps_tensor)?.sqrt()?;
+
+        // Use broadcast_div for normalization
+        let x_normed = x.broadcast_div(&std_x)?;
         let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
         match &self.bias {
             None => Ok(x),
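
The forward pass now stays entirely within candle's differentiable tensor ops (mean_keepdim, broadcast_sub, broadcast_add, broadcast_div). A minimal standalone sketch of that path, assuming candle's Var/backward API; the helper name check_grad_flow, the shapes, and the eps value are illustrative only:

// Sketch only: not part of the diff; helper name and shapes are illustrative.
use candle::{Device, Result, Tensor, Var, D};

fn check_grad_flow() -> Result<()> {
    let device = Device::Cpu;
    // A trainable 2x8 input tensor.
    let x = Var::randn(0f32, 1.0, (2, 8), &device)?;

    // Centering and variance via mean_keepdim, as in the patched forward pass.
    let centered = x.broadcast_sub(&x.mean_keepdim(D::Minus1)?)?;
    let var = centered.sqr()?.mean_keepdim(D::Minus1)?;

    // Epsilon added as a broadcast tensor rather than folded into a scalar op.
    let eps = Tensor::new(&[1e-5f32], &device)?;
    let normed = centered.broadcast_div(&var.broadcast_add(&eps)?.sqrt()?)?;

    // Backward through the composed ops records a gradient entry for x.
    let grads = normed.sqr()?.mean_all()?.backward()?;
    assert!(grads.get(&x).is_some());
    Ok(())
}
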
candle-nn/src/ops.rs (23 changes: 18 additions, 5 deletions)
@@ -655,7 +655,10 @@ pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
     xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
 }
 
+// NOTE: This LayerNorm CustomOp is no longer used due to gradient flow issues.
+// The layer_norm() function now uses layer_norm_slow() to preserve gradients.
 #[derive(Debug, Clone)]
+#[allow(dead_code)]
 struct LayerNorm {
     eps: f32,
 }
@@ -881,14 +884,23 @@ pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
         DType::F16 | DType::BF16 => DType::F32,
         d => d,
     };
-    let hidden_size = x.dim(D::Minus1)?;
     let x = x.to_dtype(internal_dtype)?;
     let x = {
-        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        // Use mean_keepdim instead of manual division to preserve gradient flow
+        let mean_x = x.mean_keepdim(D::Minus1)?;
         x.broadcast_sub(&mean_x)?
     };
-    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
+    // Use mean_keepdim for variance calculation to preserve gradient flow
+    let var_x = x.sqr()?.mean_keepdim(D::Minus1)?;
+
+    // Create epsilon as a tensor for proper gradient flow
+    let eps_tensor = Tensor::new(&[eps as f64], x.device())?.to_dtype(internal_dtype)?;
+
+    // Use broadcast_add for adding epsilon to maintain gradient tracking
+    let std_x = var_x.broadcast_add(&eps_tensor)?.sqrt()?;
+
+    // Use broadcast_div for normalization
+    let x_normed = x.broadcast_div(&std_x)?;
     x_normed
         .to_dtype(x_dtype)?
         .broadcast_mul(alpha)?
@@ -907,7 +919,8 @@ pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
             beta.shape()
         )
     }
-    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
+    // Use the gradient-preserving slow implementation to maintain backward pass
+    layer_norm_slow(xs, alpha, beta, eps)
 }
 
 // https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
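
With the dispatch changed, the public candle_nn::ops::layer_norm builds its graph from ordinary tensor ops rather than a CustomOp with no backward. A short sketch of what that enables, assuming candle's Var/backward API; the helper name check_public_op and the shapes are illustrative:

// Sketch only: not part of the diff; helper name and shapes are illustrative.
use candle::{DType, Device, Result, Tensor, Var};

fn check_public_op() -> Result<()> {
    let device = Device::Cpu;
    let xs = Tensor::randn(0f32, 1.0, (4, 16), &device)?;
    let alpha = Var::ones(16, DType::F32, &device)?;
    let beta = Var::zeros(16, DType::F32, &device)?;

    // Routed through layer_norm_slow, so the graph is differentiable end to end.
    let ys = candle_nn::ops::layer_norm(&xs, &alpha, &beta, 1e-5)?;
    let grads = ys.sqr()?.mean_all()?.backward()?;

    // Both the scale and the shift receive gradients.
    assert!(grads.get(&alpha).is_some());
    assert!(grads.get(&beta).is_some());
    Ok(())
}
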
candle-nn/tests/layer_norm.rs (112 changes: 110 additions, 2 deletions)
@@ -5,8 +5,8 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use anyhow::Result;
-use candle::{test_utils, Device, Tensor};
-use candle_nn::{LayerNorm, Module};
+use candle::{test_utils, Device, Tensor, DType};
+use candle_nn::{LayerNorm, Module, VarBuilder, VarMap};
 
 #[test]
 fn layer_norm() -> Result<()> {
@@ -53,3 +53,111 @@ fn layer_norm() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn test_layernorm_gradient_flow() -> Result<()> {
+    // Test that LayerNorm properly propagates gradients to all parameters
+    let device = &Device::Cpu;
+    let varmap = VarMap::new();
+    let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
+
+    // Create a simple model: Linear -> LayerNorm -> Linear
+    let hidden_size = 64;
+    let batch_size = 4;
+
+    // Build model components
+    let linear1 = candle_nn::linear(hidden_size, hidden_size, vb.pp("linear1"))?;
+    let layer_norm = candle_nn::layer_norm(
+        hidden_size,
+        candle_nn::LayerNormConfig::default(),
+        vb.pp("layer_norm")
+    )?;
+    let linear2 = candle_nn::linear(hidden_size, hidden_size, vb.pp("linear2"))?;
+
+    // Create input and target
+    let input = Tensor::randn(0f32, 1.0, (batch_size, hidden_size), device)?;
+    let target = Tensor::randn(0f32, 1.0, (batch_size, hidden_size), device)?;
+
+    // Forward pass
+    let x1 = linear1.forward(&input)?;
+    let x_norm = layer_norm.forward(&x1)?;
+    let output = linear2.forward(&x_norm)?;
+
+    // Compute loss (MSE)
+    let loss = (output.sub(&target))?.sqr()?.mean_all()?;
+
+    // Backward pass
+    let grads = loss.backward()?;
+
+    // Check gradient flow
+    let vars = varmap.all_vars();
+    let mut params_with_gradients = 0;
+    let mut _params_without_gradients = 0;
+
+    for var in &vars {
+        if let Some(grad) = grads.get(var) {
+            let grad_norm = grad.sqr()?.sum_all()?.sqrt()?.to_scalar::<f32>()?;
+            if grad_norm > 1e-8 {
+                params_with_gradients += 1;
+            } else {
+                _params_without_gradients += 1;
+            }
+        } else {
+            _params_without_gradients += 1;
+        }
+    }
+
+    let gradient_flow_pct = (params_with_gradients as f32 / vars.len() as f32) * 100.0;
+    println!("Gradient flow: {:.1}% ({}/{} parameters)",
+        gradient_flow_pct, params_with_gradients, vars.len());
+
+    // With the fix, we should have 100% gradient flow
+    assert!(gradient_flow_pct > 90.0,
+        "Gradient flow too low: {:.1}% (expected > 90%)", gradient_flow_pct);
+
+    Ok(())
+}
+
+#[test]
+fn test_layernorm_numerical_equivalence() -> Result<()> {
+    // Test that the fixed implementation produces the same numerical results
+    let device = &Device::Cpu;
+
+    // Test with various input shapes and values
+    let test_cases = vec![
+        (vec![1, 3], vec![1f32, 2., 3.]),
+        (vec![2, 4], vec![1f32, 2., 3., 4., 5., 6., 7., 8.]),
+        (vec![1, 2, 3], vec![1f32, 2., 3., 4., 5., 6.]),
+    ];
+
+    for (shape, data) in test_cases {
+        let input = Tensor::new(data.as_slice(), device)?.reshape(shape.as_slice())?;
+
+        // Create LayerNorm with known parameters
+        let normalized_shape = *shape.last().unwrap();
+        let weight = Tensor::ones(normalized_shape, DType::F32, device)?;
+        let bias = Tensor::zeros(normalized_shape, DType::F32, device)?;
+        let eps = 1e-5;
+
+        let layer_norm = LayerNorm::new(weight, bias, eps);
+        let output = layer_norm.forward(&input)?;
+
+        // Verify the output has the expected properties:
+        // 1. Same shape as input
+        assert_eq!(output.shape(), input.shape());
+
+        // 2. Mean should be approximately zero (within numerical precision)
+        let mean = output.mean_keepdim(candle::D::Minus1)?;
+        let mean_abs_max = mean.abs()?.flatten_all()?.max(0)?.to_scalar::<f32>()?;
+        assert!(mean_abs_max < 1e-5, "Mean not close to zero: {}", mean_abs_max);
+
+        // 3. Variance should be approximately 1 (check first element as example)
+        let centered = output.broadcast_sub(&mean)?;
+        let var = centered.sqr()?.mean_keepdim(candle::D::Minus1)?;
+        let var_flat = var.flatten_all()?;
+        let var_val = var_flat.get(0)?.to_scalar::<f32>()?;
+        assert!((var_val - 1.0).abs() < 1e-4, "Variance not close to 1: {}", var_val);
+    }
+
+    Ok(())
+}
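
As a usage note: once gradients reach the LayerNorm parameters, an optimizer step actually moves the norm's own weight and bias. A hedged sketch, assuming candle_nn's SGD/Optimizer API; the train_step helper, layer sizes, and learning rate are illustrative:

// Sketch only: not part of the diff; helper name, sizes, and learning rate are illustrative.
use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{Module, Optimizer, VarBuilder, VarMap, SGD};

fn train_step() -> Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let ln = candle_nn::layer_norm(8, candle_nn::LayerNormConfig::default(), vb.pp("ln"))?;

    let x = Tensor::randn(0f32, 1.0, (2, 8), &device)?;
    let target = Tensor::randn(0f32, 1.0, (2, 8), &device)?;

    // Snapshot one LayerNorm parameter before the update.
    let vars = varmap.all_vars();
    let before = vars[0].to_vec1::<f32>()?;

    let mut sgd = SGD::new(varmap.all_vars(), 0.1)?;
    let loss = ln.forward(&x)?.sub(&target)?.sqr()?.mean_all()?;
    sgd.backward_step(&loss)?;

    // The tracked parameter should have changed after the step.
    let after = vars[0].to_vec1::<f32>()?;
    assert!(before.iter().zip(&after).any(|(b, a)| b != a));
    Ok(())
}
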