Commit 0e1d9fa

[JAX] Bug fix for distributed normalization (NVIDIA#1366)
* fix ctx.avals_out indexing for workspace
* add cudnn init to prepare phase of norm custom calls
* add thread_local for norm registry instance

Signed-off-by: Phuong Nguyen <[email protected]>
1 parent e4c99b0 commit 0e1d9fa

3 files changed: +25 −14 lines

transformer_engine/common/normalization/common.h

Lines changed: 1 addition & 2 deletions

@@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
 
 class NormalizationPlanRegistry {
  public:
-  // TODO thread-safe
   static NormalizationPlanRegistry& getInstance() {
-    static NormalizationPlanRegistry instance;
+    static thread_local NormalizationPlanRegistry instance;
     return instance;
   }
 
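For illustration only: the change above makes the plan registry a per-thread singleton via `static thread_local`, so concurrent host threads no longer share one mutable plan cache. A rough Python analogue of that pattern (not code from this repository) uses `threading.local()` for the same effect:

import threading

class NormalizationPlanRegistry:
    """Toy stand-in for the C++ registry: caches one plan per config key."""
    def __init__(self):
        self.plans = {}

_tls = threading.local()

def get_instance() -> NormalizationPlanRegistry:
    # Each thread lazily constructs and reuses its own registry, so callers
    # never race on a shared cache; this is the role that `static thread_local`
    # plays in the C++ getInstance() above.
    if not hasattr(_tls, "registry"):
        _tls.registry = NormalizationPlanRegistry()
    return _tls.registry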

transformer_engine/jax/cpp_extensions/normalization.py

Lines changed: 6 additions & 6 deletions

@@ -147,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, output_type),
@@ -441,7 +441,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
 
         sm_margin = get_backward_sm_margin()
 
-        wkspace_aval = ctx.avals_out[-4:]
+        wkspace_aval = ctx.avals_out[-1]
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
             hidden_size,
@@ -650,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, x_type.element_type),
@@ -841,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
         hidden_size = reduce(operator.mul, g_shape)
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-3:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(x_shape, x_type.element_type),
@@ -1088,7 +1088,7 @@ def lowering(
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, ir_out_dtype),
@@ -1394,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, ir_out_dtype),
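A minimal sketch (not the TE lowering itself) of what the indexing fix changes: a slice like `ctx.avals_out[-2:]` yields a list that also captures a real output, whereas `ctx.avals_out[-1]` is the single workspace abstract value whose shape and dtype the kernel descriptor needs. The output ordering below is an assumption for illustration.

from functools import reduce
import operator

class Aval:
    """Toy stand-in for a JAX ShapedArray abstract value."""
    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype
    @property
    def size(self):
        return reduce(operator.mul, self.shape, 1)

# Assumed output layout of the forward custom call: (out, mu, rsigma, workspace).
avals_out = [
    Aval((32, 1024), "bf16"),  # normalized output
    Aval((32,), "fp32"),       # mu
    Aval((32,), "fp32"),       # rsigma
    Aval((4096,), "uint8"),    # single cuDNN workspace buffer
]

old = avals_out[-2:]  # a list [rsigma, workspace]: no .size/.dtype, mixes in a real output
new = avals_out[-1]   # the workspace aval itself

assert isinstance(old, list) and not isinstance(new, list)
print(new.size, new.dtype)  # the workspace shape/dtype information the descriptor packs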

transformer_engine/jax/csrc/extensions/pybind.cpp

Lines changed: 18 additions & 6 deletions

@@ -83,12 +83,24 @@ pybind11::dict Registrations() {
       EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
 
   // Normalization
-  dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
-  dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
-  dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
-  dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler);
-  dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler);
-  dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);
+  dict["te_layernorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler));
+  dict["te_layernorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler));
+  dict["te_layernorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler));
+  dict["te_rmsnorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler));
+  dict["te_rmsnorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler));
+  dict["te_rmsnorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler));
 
   // Attention
   pybind11::dict fused_attn_forward_ffi;
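For context, a hedged sketch of how a bundle like {"prepare": ..., "execute": ...} is typically consumed on the Python side. The stage keys follow XLA's FFI handler-bundle convention, where the "prepare" handler runs ahead of execution (here: cuDNN handle initialization before the normalization kernels). The module name `transformer_engine_jax`, the `registrations()` accessor, and the use of `jax.ffi.register_ffi_target` are assumptions for this sketch; exact names and the registration helper vary by JAX version.

import jax
import transformer_engine_jax  # pybind11 module built from pybind.cpp (assumed name)

# Each value returned by Registrations() is either a bare XLA FFI capsule or a
# dict of per-stage capsules such as {"prepare": ..., "execute": ...}.
for name, handlers in transformer_engine_jax.registrations().items():
    if name.endswith("_ffi"):
        # Registering the dict wires the "prepare" stage (cuDNN handle init)
        # and the "execute" stage (the actual norm kernel) for this target.
        jax.ffi.register_ffi_target(name, handlers, platform="CUDA")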
