Skip to content

Commit 45158bb

Browse files
committed
test: more native lux functionality unblocked
1 parent d82fb52 commit 45158bb

File tree

1 file changed

+1
-8
lines changed

1 file changed

+1
-8
lines changed

test/nn/lux.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,8 @@
11
using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays
22

3-
function crossentropy(ŷ, y)
4-
logŷ = log.(ŷ)
5-
result = y .* logŷ
6-
return -sum(result)
7-
end
8-
93
function loss_function(model, x, y, ps, st)
104
y_hat, _ = model(x, ps, st)
11-
# return CrossEntropyLoss()(y_hat, y) # <-- needs handling of xlogx xlogy from LuxOps
12-
return crossentropy(y_hat, y)
5+
return CrossEntropyLoss()(y_hat, y)
136
end
147

158
function gradient_loss_function(model, x, y, ps, st)

0 commit comments

Comments
 (0)