Skip to content

Commit aadfc7f

Browse files
committed
test: try adjusting precision
1 parent 534d3e4 commit aadfc7f

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

test/nn/lux.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,15 @@ end
7474

7575
res, dps = gradient_loss_function(model, noisy, target, ps, st)
7676

77-
compiled_gradient =
78-
Reactant.with_config(; dot_general_precision=PrecisionConfig.HIGHEST) do
79-
Reactant.compile(gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst2))
80-
end
77+
dot_general_precision = if contains(string(Reactant.devices()[1]), "CUDA")
78+
PrecisionConfig.HIGHEST
79+
else
80+
PrecisionConfig.DEFAULT
81+
end
82+
83+
compiled_gradient = Reactant.with_config(; dot_general_precision) do
84+
@compile gradient_loss_function(cmodel, cnoisy, ctarget, cps, cst2)
85+
end
8186

8287
res_reactant, dps_reactant = compiled_gradient(cmodel, cnoisy, ctarget, cps, cst2)
8388

test/runtests.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,8 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5656
# @safetestset "NNlib Primitives" include("nn/nnlib.jl")
5757
# @safetestset "Flux.jl Integration" include("nn/flux.jl")
5858
if Sys.islinux()
59-
@safetestset "LuxLib Primitives" include("nn/luxlib.jl")
60-
@info "LuxLib Primitives tests finished"
59+
# @safetestset "LuxLib Primitives" include("nn/luxlib.jl")
6160
@safetestset "Lux Integration" include("nn/lux.jl")
62-
@info "Lux Integration tests finished"
6361
end
6462
end
6563
end

0 commit comments

Comments
 (0)