Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 0c7ac83

Browse files
committed
Work around the AMDGPU issue
1 parent c09513c commit 0c7ac83

File tree

7 files changed

+33
-8
lines changed

7 files changed

+33
-8
lines changed

.buildkite/pipeline.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ steps:
77
test_args: "--quickfail"
88
- JuliaCI/julia-coverage#v1:
99
codecov: true
10+
dirs:
11+
- src
12+
- ext
1013
agents:
1114
queue: "juliagpu"
1215
cuda: "*"
@@ -27,6 +30,9 @@ steps:
2730
test_args: "--quickfail"
2831
- JuliaCI/julia-coverage#v1:
2932
codecov: true
33+
dirs:
34+
- src
35+
- ext
3036
env:
3137
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
3238
JULIA_AMDGPU_HIP_MUST_LOAD: "1"

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ jobs:
4141
RETESTITEMS_NWORKERS: 4
4242
RETESTITEMS_NWORKER_THREADS: 2
4343
- uses: julia-actions/julia-processcoverage@v1
44+
with:
45+
directories: src,ext
4446
- uses: codecov/codecov-action@v4
4547
with:
4648
files: lcov.info

Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1616
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1717
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
1818

19+
[weakdeps]
20+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
21+
22+
[extensions]
23+
LuxNeuralOperatorsAMDGPUExt = "AMDGPU"
24+
1925
[compat]
26+
AMDGPU = "0.9.5"
2027
Aqua = "0.8.7"
2128
ArgCheck = "2.3.0"
2229
ChainRulesCore = "1.24.0"

ext/LuxNeuralOperatorsAMDGPUExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module LuxNeuralOperatorsAMDGPUExt
2+
3+
using AMDGPU: AnyROCArray
4+
using LuxNeuralOperators: LuxNeuralOperators
5+
6+
# This should be upstreamed to NNlib before we release this package
7+
@inline function LuxNeuralOperators.__batched_mul(
8+
x::AnyROCArray{<:Union{ComplexF16, ComplexF32, ComplexF64}, 3},
9+
y::AnyROCArray{<:Union{ComplexF16, ComplexF32, ComplexF64}, 3})
10+
# FIXME: This is not good for performance but that is okay for now
11+
return stack(*, eachslice(x; dims=3), eachslice(y; dims=3))
12+
end
13+
14+
end

src/LuxNeuralOperators.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using PrecompileTools: @recompile_invalidations
99
using FFTW: FFTW, irfft, rfft
1010
using Lux
1111
using LuxCore: LuxCore, AbstractExplicitLayer
12-
using NNlib: NNlib, batched_transpose, ⊠
12+
using NNlib: NNlib, ⊠
1313
using Random: Random, AbstractRNG
1414
using Reexport: @reexport
1515
end
@@ -21,6 +21,7 @@ const CRC = ChainRulesCore
2121
const True = Val(true)
2222
const False = Val(false)
2323

24+
include("utils.jl")
2425
include("transform.jl")
2526

2627
include("functional.jl")

src/functional.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,6 @@ end
1717
return reshape(x_weighted, x_size[1:(N - 2)]..., size(x_weighted)[2:3]...)
1818
end
1919

20-
@inline function __apply_pattern_batched_mul(
21-
x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T1, T2}
22-
x_ = batched_transpose(x) # i x b x m
23-
res = y ⊠ x_ # o x b x m
24-
return batched_transpose(res) # m x o x b
25-
end
26-
2720
@inline __pad_modes(x, dims::Integer...) = __pad_modes(x, dims)
2821
@inline __pad_modes(x, dims::NTuple) = __pad_modes!(similar(x, dims), x)
2922

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Temporarily capture certain calls like AMDGPU for ComplexFloats
2+
@inline __batched_mul(x, y) = x ⊠ y

0 commit comments

Comments (0)