
Commit 93856cd

test: try using MSELoss directly
1 parent ae8a5b7 commit 93856cd

File tree

2 files changed (+6 -13 lines)


src/helpers/losses.jl

Lines changed: 2 additions & 1 deletion
@@ -120,7 +120,8 @@ function huber_loss(x::T1, y::T2, δ::T3) where {T1, T2, T3}
     T = promote_type(T1, T2, T3)
     diff = x - y
     abs_diff = abs(diff)
-    return ifelse(abs_diff ≤ δ, T(0.5) * abs2(diff), δ * (abs_diff - T(0.5) * δ))
+    return ifelse(
+        abs_diff ≤ δ, convert(T, 0.5) * abs2(diff), δ * (abs_diff - convert(T, 0.5) * δ))
 end
 has_custom_derivative(::typeof(huber_loss)) = true
 function derivative(::typeof(huber_loss), x::T, y::T2, δ::T3) where {T, T2, T3}
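For context, a minimal standalone sketch (not the library's full implementation) of the piecewise Huber loss that the modified line computes: quadratic below the threshold δ, linear above it. The name huber_scalar is hypothetical and only used for illustration.

# Hypothetical standalone version of the scalar Huber loss, mirroring the
# diff above: quadratic for |x - y| ≤ δ, linear otherwise.
function huber_scalar(x, y, δ)
    T = promote_type(typeof(x), typeof(y), typeof(δ))
    abs_diff = abs(x - y)
    return ifelse(abs_diff ≤ δ,
        convert(T, 0.5) * abs2(x - y),
        δ * (abs_diff - convert(T, 0.5) * δ))
end

huber_scalar(0.3f0, 0.0f0, 1.0f0)  # ≈ 0.045f0 -- quadratic branch: 0.5 * 0.3^2
huber_scalar(2.0f0, 0.0f0, 1.0f0)  # 1.5f0     -- linear branch: 1 * (2 - 0.5)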

test/reactant/training_tests.jl

Lines changed: 4 additions & 12 deletions
@@ -13,8 +13,6 @@
             Reactant.set_default_backend("cpu")
         end
 
-        # TODO: Test for compute_gradients
-
         xdev = xla_device(; force=true)
 
         @testset "MLP Training: $(version)" for version in (:iip, :oop)
@@ -41,17 +39,13 @@
 
             train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
 
-            # FIXME: Use MSELoss <-- currently fails due to Enzyme
-            function sse(model, ps, st, (x, y))
-                z, stₙ = model(x, ps, st)
-                return sum(abs2, z .- y), stₙ, (;)
-            end
-
             for epoch in 1:100, (xᵢ, yᵢ) in dataloader
                 grads, loss, stats, train_state = if version === :iip
-                    Training.single_train_step!(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state)
+                    Training.single_train_step!(
+                        AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
                 elseif version === :oop
-                    Training.single_train_step(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state)
+                    Training.single_train_step(
+                        AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
                 else
                     error("Invalid version: $(version)")
                 end
@@ -64,7 +58,5 @@
 
             @test total_final_loss < 100 * total_initial_loss
         end
-
-        # TODO: Training a CNN
     end
 end
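For reference, the hand-written loss removed by this commit follows the (model, ps, st, data) -> (loss, updated_state, stats) convention visible in the diff, while MSELoss() is the built-in loss object passed in its place. A rough sketch of the numerical difference, assuming MSELoss aggregates with mean (the usual default, not confirmed by this diff); mse_by_hand is a hypothetical helper for illustration only.

using Statistics: mean

# The removed custom loss, as it appeared in the test: sum of squared errors,
# returned as (loss, updated_state, stats).
function sse(model, ps, st, (x, y))
    z, stₙ = model(x, ps, st)
    return sum(abs2, z .- y), stₙ, (;)
end

# What MSELoss() presumably computes on the model output: the mean, not the
# sum, of squared errors. The two differ only by a constant factor
# 1/length(y), so the optimisation behaviour is essentially unchanged.
mse_by_hand(ŷ, y) = mean(abs2, ŷ .- y)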
