|
13 | 13 | Reactant.set_default_backend("cpu") |
14 | 14 | end |
15 | 15 |
|
16 | | - # TODO: Test for compute_gradients |
17 | | - |
18 | 16 | xdev = xla_device(; force=true) |
19 | 17 |
|
20 | 18 | @testset "MLP Training: $(version)" for version in (:iip, :oop) |
|
41 | 39 |
|
42 | 40 | train_state = Training.TrainState(model, ps, st, Adam(0.01f0)) |
43 | 41 |
|
44 | | - # FIXME: Use MSELoss <-- currently fails due to Enzyme |
45 | | - function sse(model, ps, st, (x, y)) |
46 | | - z, stₙ = model(x, ps, st) |
47 | | - return sum(abs2, z .- y), stₙ, (;) |
48 | | - end |
49 | | - |
50 | 42 | for epoch in 1:100, (xᵢ, yᵢ) in dataloader |
51 | 43 | grads, loss, stats, train_state = if version === :iip |
52 | | - Training.single_train_step!(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state) |
| 44 | + Training.single_train_step!( |
| 45 | + AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state) |
53 | 46 | elseif version === :oop |
54 | | - Training.single_train_step(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state) |
| 47 | + Training.single_train_step( |
| 48 | + AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state) |
55 | 49 | else |
56 | 50 | error("Invalid version: $(version)") |
57 | 51 | end |
|
64 | 58 |
|
65 | 59 | @test total_final_loss < 100 * total_initial_loss |
66 | 60 | end |
67 | | - |
68 | | - # TODO: Training a CNN |
69 | 61 | end |
70 | 62 | end |
0 commit comments