@@ -3,15 +3,23 @@ using Random, Statistics
 using Flux.Losses: logitcrossentropy
 using Flux: onecold
 using HuggingFaceDatasets
+using MLUtils
+using ImageCore
 # using ProfileView, BenchmarkTools
 
-function mnist_transform(x)
-    x = py2jl(x)
-    image = x["image"] ./ 255f0
-    label = Flux.onehotbatch(x["label"], 0:9)
+function mnist_transform(batch)
+    image = ImageCore.rawview.(ImageCore.channelview.(batch["image"])) # from Matrix{Gray{N0f8}} to Matrix{UInt8}
+    image = Flux.batch(image) ./ 255f0
+    label = Flux.onehotbatch(batch["label"], 0:9)
     return (; image, label)
 end
 
+# Remove when https://github.com/JuliaML/MLUtils.jl/pull/147 is merged and tagged
+Base.getindex(data::MLUtils.MappedData, idx::Int) = getobs(data.f(getobs(data.data, [idx])), 1)
+Base.getindex(data::MLUtils.MappedData, idxs::AbstractVector) = data.f(getobs(data.data, idxs))
+Base.getindex(data::MLUtils.MappedData, ::Colon) = data[1:length(data.data)]
+
+
 function loss_and_accuracy(data_loader, model, device)
     acc = 0
     ls = 0.0f0
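
(Aside, not part of the commit: a minimal sanity check for the new transform. It assumes the "julia"-formatted dataset yields batches as dicts holding a vector of Gray{N0f8} matrices under "image" and integer labels under "label", which is what mnist_transform expects.)

using Flux, ImageCore

fake_batch = Dict(
    "image" => [rand(Gray{N0f8}, 28, 28) for _ in 1:4], # four fake MNIST digits
    "label" => [0, 1, 2, 3],
)
out = mnist_transform(fake_batch)
size(out.image) # (28, 28, 4): matrices stacked along a trailing batch dimension
size(out.label) # (10, 4): one-hot columns over the classes 0:9
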
@@ -29,18 +37,16 @@
 function train(epochs)
     batchsize = 128
     nhidden = 100
-    device = gpu
-
-    dataset = load_dataset("mnist")
-    set_format!(dataset, "julia")
-    set_jltransform!(dataset, mnist_transform)
-
-    # We use [:] to materialize and transform the whole dataset.
-    # This gives much faster iterations.
-    # Omit the [:] if you don't want to load the whole dataset in-memory.
-    train_loader = Flux.DataLoader(dataset["train"][:]; batchsize, shuffle=true)
-    test_loader = Flux.DataLoader(dataset["test"][:]; batchsize)
+    device = cpu
 
+    train_data = load_dataset("mnist", split="train").with_format("julia")
+    test_data = load_dataset("mnist", split="test").with_format("julia")
+    train_data = mapobs(mnist_transform, train_data)[:] # lazily apply the transform, then materialize
+    test_data = mapobs(mnist_transform, test_data)[:]
+
+    train_loader = Flux.DataLoader(train_data; batchsize, shuffle=true)
+    test_loader = Flux.DataLoader(test_data; batchsize)
+
     model = Chain([Flux.flatten,
                    Dense(28 * 28, nhidden, relu),
                    Dense(nhidden, nhidden, relu),
@@ -57,7 +63,7 @@ function train(epochs)
     end
 
     report(0)
-    for epoch in 1:epochs
+    @time for epoch in 1:epochs
         for (x, y) in train_loader
             x, y = x |> device, y |> device
             loss, grads = withgradient(model -> logitcrossentropy(model(x), y), model)
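
(Aside, not part of the commit: how the temporary getindex methods behave, sketched on a plain vector standing in for the real dataset. With the patch above, mapobs stays lazy and the transform receives a whole batch of observations at once.)

using MLUtils

data = collect(1:10)
lazy = mapobs(xs -> xs .^ 2, data) # nothing is computed yet

lazy[3:5] # AbstractVector method: the transform runs on just these observations -> [9, 16, 25]
lazy[2]   # Int method: the transform runs on a one-element batch, then the result is unwrapped -> 4
lazy[:]   # Colon method: the transform runs once over the whole dataset, materializing it in memory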
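
(Aside, not part of the commit: why the [:] materialization pays off. mnist_transform returns a NamedTuple of arrays, and Flux.DataLoader can slice such a container along the last dimension directly, so each iteration is an array slice rather than a fresh pass through the transform. The shapes below are illustrative stand-ins for the real MNIST splits.)

using Flux

dataset = (; image = rand(Float32, 28, 28, 256),
             label = Flux.onehotbatch(rand(0:9, 256), 0:9))
loader = Flux.DataLoader(dataset; batchsize = 64, shuffle = true)
for (x, y) in loader # each iterate is a NamedTuple, destructured in field order
    @assert size(x) == (28, 28, 64)
    @assert size(y) == (10, 64)
end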