Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit cc3fbe9

Browse files
committed
Disable gpu on DON example
1 parent 0b7e030 commit cc3fbe9

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

example/Burgers/src/Burgers_deeponet.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
function train_don()
2-
if has_cuda()
3-
@info "CUDA is on"
4-
device = gpu
5-
CUDA.allowscalar(false)
6-
else
2+
# if has_cuda()
3+
# @info "CUDA is on"
4+
# device = gpu
5+
# CUDA.allowscalar(false)
6+
# else
77
device = cpu
8-
end
8+
# end
99

1010
x, y = get_data_don(n=300)
1111
xtrain = x[1:280, :]' |> device
12-
xval = x[end-19:end, :]' |device
12+
xval = x[end-19:end, :]' |> device
1313

1414
ytrain = y[1:280, :] |> device
1515
yval = y[end-19:end, :] |> device
@@ -20,7 +20,7 @@ function train_don()
2020
opt = ADAM(learning_rate)
2121

2222
m = DeepONet((1024,1024,1024),(1,1024,1024),gelu,gelu)
23-
loss(xtrain,ytrain,sensor) = Flux.Losses.mse(model(xtrain,sensor),ytrain)
23+
loss(xtrain,ytrain,sensor) = Flux.Losses.mse(m(xtrain,sensor),ytrain)
2424
evalcb() = @show(loss(xval,yval,grid))
2525

2626
Flux.@epochs 400 Flux.train!(loss, params(m), [(xtrain,ytrain,grid)], opt, cb = evalcb)

0 commit comments

Comments
 (0)