Skip to content

Commit 577e874

Browse files
authored
fix: Fix wrong condition in smooth l1 loss (#213)
* fix wrong condition in smooth l1 loss
* clippy & fmt
* add unit test for smooth l1 loss with negative diff
* clippy and bump up patch version
* update unit test
1 parent e993b65 commit 577e874

File tree

3 files changed

+30
-18
lines changed

3 files changed

+30
-18
lines changed

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[workspace.package]
22
# All but mlx-sys should follow the same version. mlx-sys should follow
33
# the version of mlx-c.
4-
version = "0.21.1"
4+
version = "0.21.2"
55
edition = "2021"
66
authors = [
77
"Minghua Wu <michael.wu1107@gmail.com>",
@@ -29,9 +29,9 @@ resolver = "2"
2929
[workspace.dependencies]
3030
# workspace local dependencies
3131
mlx-sys = { version = "=0.1.0", path = "mlx-sys" }
32-
mlx-macros = { version = "0.21.0", path = "mlx-macros" }
33-
mlx-internal-macros = { version = "0.21.0", path = "mlx-internal-macros" }
34-
mlx-rs = { version = "0.21.0", path = "mlx-rs" }
32+
mlx-macros = { version = "0.21", path = "mlx-macros" }
33+
mlx-internal-macros = { version = "0.21", path = "mlx-internal-macros" }
34+
mlx-rs = { version = "0.21.2", path = "mlx-rs" }
3535

3636
# external dependencies
3737
thiserror = "1"

mlx-rs/src/losses.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -523,13 +523,12 @@ impl SmoothL1Loss {
523523
let reduction = self.reduction;
524524

525525
check_shape(predictions, targets, "predictions", "targets")?;
526-
let diff = predictions.subtract(targets)?;
526+
let diff = predictions.subtract(targets)?.abs()?;
527+
let beta = array!(beta);
527528
let loss = r#where(
528-
&diff.lt(array!(beta))?,
529-
array!(0.5)
530-
.multiply(square(&diff)?)?
531-
.divide(&array!(beta))?,
532-
abs(&diff)?.subtract(array!(0.5).multiply(array!(beta))?)?,
529+
&diff.lt(&beta)?,
530+
array!(0.5).multiply(square(&diff)?)?.divide(&beta)?,
531+
diff.subtract(array!(0.5).multiply(beta)?)?,
533532
)?;
534533
reduction.reduce(loss)
535534
}
@@ -1189,6 +1188,18 @@ mod tests {
11891188
assert_array_eq!(loss_mean, expected_mean);
11901189
}
11911190

1191+
#[test]
1192+
fn test_smooth_l1_loss_negative_diff() {
1193+
let a = array!([1.5, 6.0, 0.5, 2.5]);
1194+
let b = array!([1.0, 2.0, 0.5, 3.5]);
1195+
1196+
let loss = SmoothL1Loss::new();
1197+
1198+
let ab = loss.apply(&a, &b).unwrap();
1199+
let ba = loss.apply(&b, &a).unwrap();
1200+
assert_array_eq!(ab, ba);
1201+
}
1202+
11921203
#[test]
11931204
fn test_nll_loss() {
11941205
let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);

mlx-rs/src/transforms/mod.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,22 @@ fn jvp_inner(
7272

7373
/// Compute the Jacobian-vector product.
7474
///
75-
/// This computes the product of the Jacobian of a function `f` evaluated at `primals` with the
76-
/// `tangents`.
75+
/// This computes the product of the Jacobian of a function `f` evaluated at
76+
/// `primals` with the `tangents`.
7777
///
7878
/// # Params:
7979
///
80-
/// - `f`: function which takes an array of `Array` and returns an array of `Array`
80+
/// - `f`: function which takes an array of `Array` and returns an array of
81+
/// `Array`
8182
/// - `primals`: array of `Array` at which to evaluate the Jacobian
82-
/// - `tangents`: array of `Array` which are the "vector" in the Jacobian-vector product. The
83-
/// `tangents` should be the same in number, shape and type as the inputs of `f`, e.g. the
84-
/// `primals`
83+
/// - `tangents`: array of `Array` which are the "vector" in the Jacobian-vector
84+
/// product. The `tangents` should be the same in number, shape and type as
85+
/// the inputs of `f`, e.g. the `primals`
8586
///
8687
/// # Returns:
8788
///
88-
/// Array of the Jacobian-vector products which is the same in number, shape and type of
89-
/// the outputs of `f`
89+
/// Array of the Jacobian-vector products which is the same in number, shape and
90+
/// type of the outputs of `f`
9091
pub fn jvp<'a, F>(f: F, primals: &[Array], tangents: &[Array]) -> Result<(Vec<Array>, Vec<Array>)>
9192
where
9293
F: FnMut(&[Array]) -> Vec<Array> + 'a,

0 commit comments

Comments
 (0)