Commit d245ae5

feat: batchnorm ops (#1336)
* feat: batchnorm ops
* fix: use jll with bn grad fix
* test: bn
* feat: add the grad op as well
1 parent 0139886 commit d245ae5

3 files changed: 228 additions, 1 deletion

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.10"
-Reactant_jll = "0.0.188"
+Reactant_jll = "0.0.189"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"

src/Ops.jl

Lines changed: 126 additions & 0 deletions
@@ -2968,4 +2968,130 @@ end
     ]
 end
 
+@noinline function batch_norm_inference(
+    operand::TracedRArray{T,N},
+    scale::Union{TracedRArray{T,1},Nothing},
+    offset::Union{TracedRArray{T,1},Nothing},
+    mean::TracedRArray{T,1},
+    variance::TracedRArray{T,1};
+    epsilon,
+    feature_index::Int64,
+    location=mlir_stacktrace("batch_norm_inference", @__FILE__, @__LINE__),
+) where {T,N}
+    len = size(operand, feature_index)
+    @assert length(mean) == length(variance) == len
+
+    if scale === nothing
+        scale = fill(T(1), len; location)
+    else
+        @assert size(scale) == (len,)
+    end
+
+    if offset === nothing
+        offset = fill(T(0), len; location)
+    else
+        @assert size(offset) == (len,)
+    end
+
+    return TracedRArray{T,N}(
+        (),
+        MLIR.IR.result(
+            stablehlo.batch_norm_inference(
+                operand.mlir_data,
+                scale.mlir_data,
+                offset.mlir_data,
+                mean.mlir_data,
+                variance.mlir_data;
+                epsilon=Float32(epsilon),
+                feature_index=feature_index - 1,
+                location,
+            ),
+            1,
+        ),
+        size(operand),
+    )
+end
+
+@noinline function batch_norm_training(
+    operand::TracedRArray{T,N},
+    scale::Union{TracedRArray{T,1},Nothing},
+    offset::Union{TracedRArray{T,1},Nothing};
+    epsilon,
+    feature_index::Int64,
+    location=mlir_stacktrace("batch_norm_training", @__FILE__, @__LINE__),
+) where {T,N}
+    len = size(operand, feature_index)
+
+    if scale === nothing
+        scale = fill(T(1), len; location)
+    else
+        @assert size(scale) == (len,)
+    end
+
+    if offset === nothing
+        offset = fill(T(0), len; location)
+    else
+        @assert size(offset) == (len,)
+    end
+
+    batch_norm_train_op = stablehlo.batch_norm_training(
+        operand.mlir_data,
+        scale.mlir_data,
+        offset.mlir_data;
+        epsilon=Float32(epsilon),
+        feature_index=feature_index - 1,
+        location,
+    )
+
+    return (
+        TracedRArray{T,N}((), MLIR.IR.result(batch_norm_train_op, 1), size(operand)),
+        TracedRArray{T,1}((), MLIR.IR.result(batch_norm_train_op, 2), (len,)),
+        TracedRArray{T,1}((), MLIR.IR.result(batch_norm_train_op, 3), (len,)),
+    )
+end
+
+@noinline function batch_norm_grad(
+    operand::TracedRArray{T,N},
+    scale::Union{TracedRArray{T,1},Nothing},
+    mean::TracedRArray{T,1},
+    variance::TracedRArray{T,1},
+    grad_output::TracedRArray{T,N};
+    epsilon,
+    feature_index::Int64,
+    location=mlir_stacktrace("batch_norm_grad", @__FILE__, @__LINE__),
+) where {T,N}
+    len = size(operand, feature_index)
+    @assert length(mean) == length(variance) == len
+    @assert size(grad_output) == size(operand)
+
+    has_affine = scale !== nothing
+
+    if !has_affine
+        scale = fill(T(1), len; location)
+    else
+        @assert size(scale) == (len,)
+    end
+
+    batch_norm_grad_op = stablehlo.batch_norm_grad(
+        operand.mlir_data,
+        scale.mlir_data,
+        mean.mlir_data,
+        variance.mlir_data,
+        grad_output.mlir_data;
+        epsilon=Float32(epsilon),
+        feature_index=feature_index - 1,
+        location,
+    )
+
+    grad_operand = TracedRArray{T,N}(
+        (), MLIR.IR.result(batch_norm_grad_op, 1), size(operand)
+    )
+    grad_scale = TracedRArray{T,1}((), MLIR.IR.result(batch_norm_grad_op, 2), (len,))
+    grad_offset = TracedRArray{T,1}((), MLIR.IR.result(batch_norm_grad_op, 3), (len,))
+
+    return (
+        grad_operand, has_affine ? grad_scale : nothing, has_affine ? grad_offset : nothing
+    )
+end
+
 end # module Ops
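For orientation, stablehlo.batch_norm_inference normalizes each feature slice with the supplied statistics. A minimal plain-Julia sketch of those semantics, assuming a 1-based feature_index as in the wrappers above (the name ref_batch_norm_inference is hypothetical, not part of this commit; the real path is the StableHLO lowering):

# Reference semantics only; not the commit's code. feature_index is 1-based
# here, which is why the wrappers pass feature_index - 1 to StableHLO's
# 0-based attribute.
function ref_batch_norm_inference(x, scale, offset, mean, variance;
                                  epsilon, feature_index)
    # Shape that broadcasts the per-feature vectors along feature_index.
    shp = ntuple(i -> i == feature_index ? size(x, i) : 1, ndims(x))
    γ, β = reshape(scale, shp), reshape(offset, shp)
    μ, σ² = reshape(mean, shp), reshape(variance, shp)
    return γ .* (x .- μ) ./ sqrt.(σ² .+ epsilon) .+ β
end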

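The backward op follows the standard batch-norm gradient. A hedged sketch, under the usual training-mode assumption that mean and variance are the (biased) batch statistics over the non-feature dimensions; ref_batch_norm_grad is an illustrative name, not this commit's API:

# Standard batch-norm backward, for illustration only.
function ref_batch_norm_grad(x, scale, mean, variance, g; epsilon, feature_index)
    shp = ntuple(i -> i == feature_index ? size(x, i) : 1, ndims(x))
    dims = Tuple(i for i in 1:ndims(x) if i != feature_index)
    m = length(x) ÷ size(x, feature_index)        # elements reduced per feature
    x̂ = (x .- reshape(mean, shp)) ./ sqrt.(reshape(variance, shp) .+ epsilon)
    grad_scale  = dropdims(sum(g .* x̂; dims); dims)   # second result of the op
    grad_offset = dropdims(sum(g; dims); dims)        # third result of the op
    grad_x = reshape(scale, shp) ./ sqrt.(reshape(variance, shp) .+ epsilon) .*
             (g .- (reshape(grad_offset, shp) .+ x̂ .* reshape(grad_scale, shp)) ./ m)
    return grad_x, grad_scale, grad_offset
end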
test/ops.jl

Lines changed: 101 additions & 0 deletions
@@ -1192,3 +1192,104 @@ end
     @test @jit(recon_from_lu(lu_ra)) ≈ @jit(apply_permutation(x_ra, perm))
 end
 end
+
+@testset "batch norm" begin
+    @testset "training" begin
+        @testset for affine in [false, true]
+            x = Reactant.to_rarray(randn(2, 3, 4, 5))
+            if affine
+                scale = Reactant.to_rarray(randn(3))
+                offset = Reactant.to_rarray(randn(3))
+            else
+                scale, offset = nothing, nothing
+            end
+
+            hlo = @code_hlo Ops.batch_norm_training(
+                x, scale, offset; epsilon=1e-5, feature_index=2
+            )
+            @test occursin("stablehlo.batch_norm_training", repr(hlo))
+
+            if !affine
+                @test occursin(
+                    "stablehlo.constant dense<0.000000e+00> : tensor<3xf64>", repr(hlo)
+                )
+                @test occursin(
+                    "stablehlo.constant dense<1.000000e+00> : tensor<3xf64>", repr(hlo)
+                )
+            end
+
+            res, m, v = @jit Ops.batch_norm_training(
+                x, scale, offset; epsilon=1e-5, feature_index=2
+            )
+            @test size(res) == size(x)
+            @test size(m) == (3,)
+            @test size(v) == (3,)
+        end
+    end
+
+    @testset "inference" begin
+        @testset for affine in [false, true]
+            x = Reactant.to_rarray(randn(2, 3, 4, 5))
+            if affine
+                scale = Reactant.to_rarray(randn(3))
+                offset = Reactant.to_rarray(randn(3))
+            else
+                scale, offset = nothing, nothing
+            end
+
+            rm = Reactant.to_rarray(randn(3))
+            rv = Reactant.to_rarray(rand(3))
+
+            hlo = @code_hlo Ops.batch_norm_inference(
+                x, scale, offset, rm, rv; epsilon=1e-5, feature_index=2
+            )
+            @test occursin("stablehlo.batch_norm_inference", repr(hlo))
+            if !affine
+                @test occursin(
+                    "stablehlo.constant dense<0.000000e+00> : tensor<3xf64>", repr(hlo)
+                )
+                @test occursin(
+                    "stablehlo.constant dense<1.000000e+00> : tensor<3xf64>", repr(hlo)
+                )
+            end
+
+            res = @jit Ops.batch_norm_inference(
+                x, scale, offset, rm, rv; epsilon=1e-5, feature_index=2
+            )
+            @test size(res) == size(x)
+        end
+    end
+
+    @testset "batch_norm_grad" begin
+        @testset for affine in [false, true]
+            x = Reactant.to_rarray(randn(2, 3, 4, 5))
+            scale = affine ? Reactant.to_rarray(randn(3)) : nothing
+            rm = Reactant.to_rarray(randn(3))
+            rv = Reactant.to_rarray(rand(3))
+            gx = Reactant.to_rarray(randn(2, 3, 4, 5))
+
+            hlo = @code_hlo Ops.batch_norm_grad(
+                x, scale, rm, rv, gx; epsilon=1e-5, feature_index=2
+            )
+            @test occursin("stablehlo.batch_norm_grad", repr(hlo))
+
+            if !affine
+                @test occursin(
+                    "stablehlo.constant dense<1.000000e+00> : tensor<3xf64>", repr(hlo)
+                )
+            end
+
+            gres, gscale, goffset = @jit Ops.batch_norm_grad(
+                x, scale, rm, rv, gx; epsilon=1e-5, feature_index=2
+            )
+            @test size(gres) == size(x)
+            if !affine
+                @test gscale === nothing
+                @test goffset === nothing
+            else
+                @test size(gscale) == (3,)
+                @test size(goffset) == (3,)
+            end
+        end
+    end
+end
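Putting it together, a hypothetical end-to-end call mirroring the tests, assuming "using Reactant" and that Ops refers to Reactant.Ops as in the test file (feature_index=2 selects the second dimension as the channel axis):

using Reactant
using Reactant: Ops   # assumption: the Ops module, as the tests use it

x = Reactant.to_rarray(randn(2, 3, 4, 5))   # features along dimension 2
# Training mode: normalized output plus per-feature batch statistics.
y, batch_mean, batch_var = @jit Ops.batch_norm_training(
    x, nothing, nothing; epsilon=1e-5, feature_index=2
)
# Inference mode: reuse the statistics computed above.
y_inf = @jit Ops.batch_norm_inference(
    x, nothing, nothing, batch_mean, batch_var; epsilon=1e-5, feature_index=2
)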
