Flux: slower and slower computation of a neural network - a VRAM problem? #654

Description

@Vintarel

This small piece of code:

input = rand(Float32, 6, 6, 2, 1) |> gpu  
output = nn(input) |> cpu  

gets slower and slower to compute on my machine. Specifically, the full code:

using Flux, AMDGPU

nn = Chain(
    Conv((3,3), 2=>128, relu, pad=0),
    Conv((3,3), 128=>128, relu, pad=1),
    Conv((1,1), 128=>1, relu),
    Flux.flatten,
    Dense(16 => 256, relu),
    Dense(256=>1, sigmoid)
) |> gpu

for i in 1:20
  @time for j in 1:2000
      input = rand(Float32, 6, 6, 2, 1) |> gpu
      output = nn(input) |> cpu
  end
end

prints the following computation times:

  9.989851 seconds (13.81 M allocations: 876.953 MiB, 2.33% gc time, 80.64% compilation time)
  1.678728 seconds (880.00 k allocations: 36.392 MiB)
  3.435393 seconds (1.03 M allocations: 40.158 MiB, 0.55% gc time)
  3.878582 seconds (1.04 M allocations: 40.348 MiB, 0.48% gc time)
  4.033464 seconds (1.04 M allocations: 40.417 MiB, 0.58% gc time)
  4.567139 seconds (880.00 k allocations: 36.392 MiB)
  7.569016 seconds (1.04 M allocations: 40.452 MiB, 0.26% gc time)
  6.075572 seconds (1.04 M allocations: 40.448 MiB, 0.30% gc time)
  6.280088 seconds (1.04 M allocations: 40.448 MiB, 0.30% gc time)
  7.004467 seconds (1.04 M allocations: 40.448 MiB, 0.27% gc time)
  6.084493 seconds (880.01 k allocations: 36.392 MiB)
  8.216904 seconds (1.04 M allocations: 40.449 MiB, 0.23% gc time)
  9.743433 seconds (1.04 M allocations: 40.449 MiB, 0.20% gc time)
 10.631787 seconds (1.04 M allocations: 40.449 MiB, 0.18% gc time)
  9.975057 seconds (880.01 k allocations: 36.392 MiB)
 11.235186 seconds (1.04 M allocations: 40.449 MiB, 0.17% gc time)
 20.558719 seconds (1.04 M allocations: 40.449 MiB, 0.09% gc time)
 25.954910 seconds (1.04 M allocations: 40.449 MiB, 0.07% gc time)
 19.961299 seconds (880.01 k allocations: 36.392 MiB)
 17.356931 seconds (1.04 M allocations: 40.449 MiB, 0.11% gc time)

so it goes from about 3 s to more than 20 s for the same piece of code! I checked the VRAM usage of my GPU (with cat /sys/class/drm/card1/device/mem_info_vram_used) and it increases strictly while the code above runs. Maybe this is the source of the problem? But I have not been able to free it.
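
For reference, one can sample that same counter from within the timing loop, along these lines (the card1 path is specific to my machine, and vram_used_mib is just a helper I define here):

vram_used_mib() = parse(Int, strip(read("/sys/class/drm/card1/device/mem_info_vram_used", String))) ÷ 2^20

for i in 1:20
    @time for j in 1:2000
        input = rand(Float32, 6, 6, 2, 1) |> gpu
        output = nn(input) |> cpu
    end
    println("VRAM used: ", vram_used_mib(), " MiB")  # strictly increasing batch after batch
end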

I tried several simple things such as finalize(input), but none of them solved the problem. The CPU-only version (the same code without the |> gpu and |> cpu transfers) runs at a steady speed. Please help!
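
Concretely, the kind of cleanup I have been attempting looks like the sketch below. I am assuming AMDGPU.unsafe_free! is the right way to eagerly release a ROCArray's buffer (by analogy with other GPUArrays backends); variants along these lines, finalize(input) among them, did not change the behaviour:

for i in 1:20
    @time for j in 1:2000
        input = rand(Float32, 6, 6, 2, 1) |> gpu
        output = nn(input) |> cpu
        AMDGPU.unsafe_free!(input)  # eagerly release the device buffer
        # finalize(input)           # also tried; same result
    end
    GC.gc(true)                     # force a full collection between batches
end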

The GPU is an AMD Radeon RX 6700 XT; I am on Manjaro, kernel 6.9.9-1.
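
One experiment that might narrow this down: reusing a single preallocated device buffer instead of allocating a fresh input every iteration, so that only the allocator behaviour changes. A sketch (assuming AMDGPU.zeros allocates a ROCArray, as CUDA.zeros does for CuArrays):

input_gpu = AMDGPU.zeros(Float32, 6, 6, 2, 1)  # one device buffer, reused every iteration

for i in 1:20
    @time for j in 1:2000
        copyto!(input_gpu, rand(Float32, 6, 6, 2, 1))  # host-to-device copy into the same buffer
        output = nn(input_gpu) |> cpu
    end
end

If the slowdown disappears with this variant, the per-iteration device allocations are the likely culprit.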
