-
Notifications
You must be signed in to change notification settings - Fork 652
Expand file tree
/
Copy pathpybind.cpp
More file actions
527 lines (493 loc) · 32.2 KB
/
pybind.cpp
File metadata and controls
527 lines (493 loc) · 32.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "pybind.h"
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <optional>
#include <vector>
#include "../common.h"
#include "../extensions.h"
#include "common.h"
namespace transformer_engine::pytorch {
// Python type objects for the quantized tensor classes, their storage classes
// and their quantizers. Each pointer starts out null and is assigned by the
// matching init_*() function in this file, which looks the class up by name
// in its Python module.

// Assigned in init_float8_extension().
PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *Float8TensorStoragePythonClass = nullptr;
PyTypeObject *Float8QuantizerClass = nullptr;
PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
// Assigned in init_mxfp8_extension().
PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *MXFP8TensorStoragePythonClass = nullptr;
PyTypeObject *MXFP8QuantizerClass = nullptr;
// Assigned in init_float8blockwise_extension().
PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorStoragePythonClass = nullptr;
PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
// Assigned in init_nvfp4_extensions().
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
// Flag handed to std::call_once in init_extension().
// NOTE(review): std::once_flag lives in <mutex>, which this file does not
// include directly - presumably it arrives through another header; confirm.
std::once_flag extension_init_flag;
// Looks up the Python classes used by the Float8 extension and caches them in
// the file-level PyTypeObject globals.
//
// The quantizer and tensor classes are read from
// transformer_engine.pytorch.tensor.float8_tensor; the storage class is read
// from transformer_engine.pytorch.tensor.storage.float8_tensor_storage.
//
// PyObject_GetAttrString returns NULL when the attribute is missing, so every
// cached pointer is validated with NVTE_CHECK before returning. A missing
// class is then reported here instead of surfacing later as a null type
// pointer.
void init_float8_extension() {
  auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
  Float8QuantizerClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
  Float8CurrentScalingQuantizerClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(fp8_module.ptr(), "Float8CurrentScalingQuantizer"));
  Float8TensorPythonClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor"));
  auto fp8_base_module =
      py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage");
  Float8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage"));
  // Validate every lookup, not only the tensor class.
  NVTE_CHECK(Float8QuantizerClass != nullptr,
             "Internal error: could not initialize pyTorch Float8 extension.");
  NVTE_CHECK(Float8CurrentScalingQuantizerClass != nullptr,
             "Internal error: could not initialize pyTorch Float8 extension.");
  NVTE_CHECK(Float8TensorPythonClass != nullptr,
             "Internal error: could not initialize pyTorch Float8 extension.");
  NVTE_CHECK(Float8TensorStoragePythonClass != nullptr,
             "Internal error: could not initialize pyTorch Float8 extension.");
}
// Looks up the Python classes used by the MXFP8 extension and caches them in
// the file-level PyTypeObject globals.
//
// The quantizer and tensor classes are read from
// transformer_engine.pytorch.tensor.mxfp8_tensor; the storage class is read
// from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage.
//
// PyObject_GetAttrString returns NULL when the attribute is missing, so every
// cached pointer is validated with NVTE_CHECK before returning.
void init_mxfp8_extension() {
  auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor");
  MXFP8QuantizerClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer"));
  MXFP8TensorPythonClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor"));
  auto fp8_base_module =
      py::module_::import("transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage");
  MXFP8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorStorage"));
  // Validate every lookup, not only the tensor class.
  NVTE_CHECK(MXFP8QuantizerClass != nullptr,
             "Internal error: could not initialize pyTorch MXFP8 extension.");
  NVTE_CHECK(MXFP8TensorPythonClass != nullptr,
             "Internal error: could not initialize pyTorch MXFP8 extension.");
  NVTE_CHECK(MXFP8TensorStoragePythonClass != nullptr,
             "Internal error: could not initialize pyTorch MXFP8 extension.");
}
void init_float8blockwise_extension() {
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
"transformer_engine.pytorch.tensor.storage.float8_blockwise_tensor_storage");
Float8BlockwiseQuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer"));
Float8BlockwiseQTensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorStorage"));
Float8BlockwiseQTensorPythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor"));
NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorStoragePythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
}
// Looks up the Python classes used by the NVFP4 extension and caches them in
// the file-level PyTypeObject globals.
//
// The quantizer and tensor classes are read from
// transformer_engine.pytorch.tensor.nvfp4_tensor; the storage class is read
// from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage.
//
// PyObject_GetAttrString returns NULL when the attribute is missing, so every
// cached pointer is validated with NVTE_CHECK before returning.
void init_nvfp4_extensions() {
  auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
  NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
  NVFP4TensorPythonClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor"));
  auto nvfp4_base_module =
      py::module_::import("transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage");
  NVFP4TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorStorage"));
  // Validate every lookup, not only the tensor class.
  NVTE_CHECK(NVFP4QuantizerClass != nullptr,
             "Internal error: could not initialize pyTorch NVFP4 extension.");
  NVTE_CHECK(NVFP4TensorPythonClass != nullptr,
             "Internal error: could not initialize pyTorch NVFP4 extension.");
  NVTE_CHECK(NVFP4TensorStoragePythonClass != nullptr,
             "Internal error: could not initialize pyTorch NVFP4 extension.");
}
// Entry point for populating the cached Python type objects. The four
// per-format initializers are run through std::call_once on
// extension_init_flag, in the order Float8, MXFP8, Float8 blockwise, NVFP4.
void init_extension() {
  const auto load_all_python_classes = [] {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
  };
  std::call_once(extension_init_flag, load_all_python_classes);
}
} // namespace transformer_engine::pytorch
#include "common/util/pybind_helper.h"
// Python module definition: registers every function, enum and class that the
// extension exposes on module `m`.
//
// Bindings registered with py::call_guard<py::gil_scoped_release> run with the
// GIL released for the duration of the C++ call; the remaining bindings keep
// the GIL held.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Macro from common/util/pybind_helper.h.
  // NOTE(review): presumably registers handles shared with other bindings - confirm there.
  NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m)
  m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"),
        py::arg("output") = py::none(), py::arg("noop") = py::none());
  m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
        py::arg("otype"));
  m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
        "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"));
  m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)",
        py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"),
        py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"),
        py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
        py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
        py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
        py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
        py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
  /* GLU (sigmoid gate) */
  m.def("glu", transformer_engine::pytorch::glu, "GLU activation", py::arg("input"),
        py::arg("quantizer"));
  /* GELU and variants*/
  m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"),
        py::arg("quantizer"));
  /* ReLU and variants */
  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("sreglu", transformer_engine::pytorch::sreglu, "Squared ReGLU activation", py::arg("input"),
        py::arg("quantizer"));
  /* SwiGLU and variants */
  m.def("silu", transformer_engine::pytorch::silu, "SiLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
        "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
        py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
  /* Backward of GLU */
  m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  /* Backward of GELU and variants */
  m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  /* Backward of ReLU and variants */
  m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dsreglu", transformer_engine::pytorch::dsreglu, "Backward of Squared ReGLU",
        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
  /* Backward of SiLU and variants */
  m.def("dsilu", transformer_engine::pytorch::dsilu, "Backward of SiLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu,
        "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"),
        py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
  /* DBias + DAct fusions*/
  m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize",
        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dbias_drelu", transformer_engine::pytorch::dbias_drelu, "DReLU + DBias + Quantize",
        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dbias_dqgelu", transformer_engine::pytorch::dbias_dqgelu, "DQGeLU + DBias + Quantize",
        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dbias_dsrelu", transformer_engine::pytorch::dbias_dsrelu,
        "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"),
        py::arg("quantizer"));
  // Permutation functions
  m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD",
        py::call_guard<py::gil_scoped_release>());
  m.def("moe_permute_bwd", transformer_engine::pytorch::moe_permute_bwd, "MOE permute BWD",
        py::call_guard<py::gil_scoped_release>());
  m.def("moe_unpermute_fwd", transformer_engine::pytorch::moe_unpermute_fwd, "MOE unpermute FWD",
        py::call_guard<py::gil_scoped_release>());
  m.def("moe_unpermute_bwd", transformer_engine::pytorch::moe_unpermute_bwd, "MOE unpermute BWD",
        py::call_guard<py::gil_scoped_release>());
  // Softmax functions
  m.def("scaled_softmax_forward", &transformer_engine::pytorch::scaled_softmax_forward,
        "Scaled Softmax FWD", py::call_guard<py::gil_scoped_release>());
  m.def("scaled_softmax_backward", &transformer_engine::pytorch::scaled_softmax_backward,
        "Scaled Softmax BWD", py::call_guard<py::gil_scoped_release>());
  m.def("scaled_masked_softmax_forward",
        &transformer_engine::pytorch::scaled_masked_softmax_forward, "Scaled Masked Softmax FWD",
        py::call_guard<py::gil_scoped_release>());
  m.def("scaled_masked_softmax_backward",
        &transformer_engine::pytorch::scaled_masked_softmax_backward, "Scaled Masked Softmax BWD",
        py::call_guard<py::gil_scoped_release>());
  m.def("scaled_upper_triang_masked_softmax_forward",
        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_forward,
        "Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
  m.def("scaled_upper_triang_masked_softmax_backward",
        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_backward,
        "Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
  m.def("scaled_aligned_causal_masked_softmax_forward",
        &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_forward,
        "Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
        py::call_guard<py::gil_scoped_release>());
  m.def("scaled_aligned_causal_masked_softmax_backward",
        &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_backward,
        "Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
        py::call_guard<py::gil_scoped_release>());
  // Other granular functions
  m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"),
        py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
  m.def("layernorm_bwd", &transformer_engine::pytorch::layernorm_bwd, "Backward of LayerNorm");
  m.def("rmsnorm_fwd", &transformer_engine::pytorch::rmsnorm_fwd, "RMSNorm", py::arg("input"),
        py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
  m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
  m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add,
        "Fused backward of RMSNorm + add");
  m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize,
        "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
  m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
        "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
        py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false);
  m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
        "Grouped GEMM");
  m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
        py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
        py::call_guard<py::gil_scoped_release>());
  m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims,
        "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"),
        py::call_guard<py::gil_scoped_release>());
  m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend,
        "Get Fused Attention backend", py::call_guard<py::gil_scoped_release>());
  m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
        "Compute absolute max value in tensor", py::arg("input"), py::arg("amax"),
        py::call_guard<py::gil_scoped_release>());
  m.def("fused_amax_and_scale_update_after_reduction",
        &transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction,
        "Update amax history and FP8 scale/scale_inv after reduction",
        py::call_guard<py::gil_scoped_release>());
  m.def("fp8_block_scaling_compute_partial_amax",
        &transformer_engine::pytorch::fp8_block_scaling_compute_partial_amax,
        "Compute partial amax from master weights for fp8 block scaling", py::arg("tensor"),
        py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
        py::call_guard<py::gil_scoped_release>());
  m.def("fp8_block_scaling_partial_cast",
        &transformer_engine::pytorch::fp8_block_scaling_partial_cast,
        "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
        py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
        py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
  m.def("mxfp8_scaling_compute_partial_amax",
        &transformer_engine::pytorch::mxfp8_scaling_compute_partial_amax,
        "Compute partial amax from master weights for fp8 mxfp8 scaling", py::arg("input"),
        py::arg("amax_rowwise"), py::arg("amax_colwise"), py::arg("rows"), py::arg("cols"),
        py::arg("start_offset"), py::call_guard<py::gil_scoped_release>());
  m.def("mxfp8_scaling_partial_cast", &transformer_engine::pytorch::mxfp8_scaling_partial_cast,
        "Partial cast from master weights for fp8 mxfp8 scaling", py::arg("input"),
        py::arg("output_rowwise"), py::arg("output_colwise"), py::arg("scale_inv_rowwise"),
        py::arg("scale_inv_colwise"), py::arg("rows"), py::arg("cols"), py::arg("start_offset"),
        py::call_guard<py::gil_scoped_release>());
  m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
        "Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
  m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding,
        "Fused Multi-tensor unpadding", py::call_guard<py::gil_scoped_release>());
  m.def("swizzle_scales_for_gemm_", &transformer_engine::pytorch::inplace_swizzle_scale_for_gemm,
        "Convert tensor block scales into GEMM swizzled format");
  // attention kernels
  m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
        "Prepare QKV for Flash Attention", py::call_guard<py::gil_scoped_release>());
  m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd,
        "Backward of QKV preparation for Flash Attention",
        py::call_guard<py::gil_scoped_release>());
  m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd,
        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
  m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd,
        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
  m.def("copy_to_kv_cache", &transformer_engine::pytorch::copy_to_kv_cache,
        "Copy new KV tokens to KV cache", py::call_guard<py::gil_scoped_release>());
  m.def("convert_thd_to_bshd", &transformer_engine::pytorch::convert_thd_to_bshd,
        "Convert a tensor from THD to BSHD", py::call_guard<py::gil_scoped_release>());
  m.def("convert_bshd_to_thd", &transformer_engine::pytorch::convert_bshd_to_thd,
        "Convert a tensor from BSHD to THD", py::call_guard<py::gil_scoped_release>());
  // fused apply rope
  m.def("fused_rope_forward", &transformer_engine::pytorch::fused_rope_forward,
        "Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
  m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
        "Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
  m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward,
        "Fused Apply QKV RoPE FWD", py::call_guard<py::gil_scoped_release>());
  m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward,
        "Fused Apply QKV RoPE BWD", py::call_guard<py::gil_scoped_release>());
  // fused router
  m.def("fused_topk_with_score_function_fwd",
        &transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"),
        py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"),
        py::arg("scaling_factor"), py::arg("score_function"), py::arg("expert_bias"),
        "Fused topk softmax fwd");
  m.def("fused_topk_with_score_function_bwd",
        &transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"),
        py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"),
        py::arg("grad_probs"), py::arg("topk"), py::arg("use_pre_softmax"),
        py::arg("scaling_factor"), py::arg("score_function"), "Fused topk softmax bwd");
  // NOTE(review): the two aux-loss score bindings below reuse the
  // "Fused topk softmax fwd/bwd" docstrings of the bindings above - looks
  // copy-pasted; confirm the intended wording before changing it.
  m.def("fused_score_for_moe_aux_loss_fwd",
        &transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"),
        py::arg("topk"), py::arg("score_function"), "Fused topk softmax fwd");
  m.def("fused_score_for_moe_aux_loss_bwd",
        &transformer_engine::pytorch::fused_score_for_moe_aux_loss_bwd, py::arg("num_tokens"),
        py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"),
        py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd");
  m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd,
        py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"),
        py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"),
        py::arg("coeff"), "Fused aux loss fwd");
  m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd,
        py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"),
        py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
  // Dropout
  m.def("dropout_fwd", transformer_engine::pytorch::dropout_fwd, "Dropout forward with 8-bit RNG",
        py::arg("input"), py::arg("dropout_probability"), py::arg("out") = std::nullopt);
  m.def("dropout_bwd", transformer_engine::pytorch::dropout_bwd, "Dropout backward with 8-bit RNG",
        py::arg("grad_output"), py::arg("mask"), py::arg("dropout_probability"),
        py::arg("grad_input") = std::nullopt);
  // Misc
  m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
        "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
  m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
        py::call_guard<py::gil_scoped_release>());
  m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams",
        py::call_guard<py::gil_scoped_release>());
  // Support THD format for Context Parallel
  m.def("thd_read_half_tensor", &transformer_engine::pytorch::thd_read_half_tensor,
        "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
        "tensor",
        py::call_guard<py::gil_scoped_release>());
  m.def("thd_second_half_lse_correction",
        &transformer_engine::pytorch::thd_second_half_lse_correction,
        "Correct the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
  m.def("thd_read_second_half_lse", &transformer_engine::pytorch::thd_read_second_half_lse,
        "Read the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
  m.def("thd_out_correction", &transformer_engine::pytorch::thd_out_correction,
        "Correct the THD format output of context parallelism in forward pass",
        py::call_guard<py::gil_scoped_release>());
  m.def("thd_grad_correction", &transformer_engine::pytorch::thd_grad_correction,
        "Correct the THD format gradients of context parallelism in backward pass",
        py::call_guard<py::gil_scoped_release>());
  m.def("thd_get_partitioned_indices", &transformer_engine::pytorch::thd_get_partitioned_indices,
        "Generate partitioned indices for inputs in THD format",
        py::call_guard<py::gil_scoped_release>());
  // nvshmem functions
  m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_nvshmem_backend,
        "Initialize nvshmem backend with Pytorch distributed process groups",
        py::call_guard<py::gil_scoped_release>());
  m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_nvshmem_tensor,
        "Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
  m.def("nvshmem_send_on_current_stream",
        &transformer_engine::pytorch::nvshmem_send_on_current_stream,
        "Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
        py::call_guard<py::gil_scoped_release>());
  m.def("nvshmem_wait_on_current_stream",
        &transformer_engine::pytorch::nvshmem_wait_on_current_stream,
        "Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
        "stream",
        py::call_guard<py::gil_scoped_release>());
  m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize,
        "Clean up and finalize the NVSHMEM communication backend and free associated resources",
        py::call_guard<py::gil_scoped_release>());
  // multi-tensor functions
  m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors",
        py::call_guard<py::gil_scoped_release>());
  m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors",
        py::call_guard<py::gil_scoped_release>());
  m.def("multi_tensor_unscale_l2norm",
        &transformer_engine::pytorch::multi_tensor_unscale_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only "
        "performed for L2 norm computation, and tensors are not updated)",
        py::call_guard<py::gil_scoped_release>());
  m.def("multi_tensor_adam", &transformer_engine::pytorch::multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer",
        py::call_guard<py::gil_scoped_release>());
  // The first literal ends with a space so the two concatenated pieces do not
  // run together in the resulting docstring.
  m.def("multi_tensor_adam_param_remainder",
        &transformer_engine::pytorch::multi_tensor_adam_param_remainder_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer "
        "where the master parameters only store the remainder bits",
        py::call_guard<py::gil_scoped_release>());
  m.def("multi_tensor_adam_fp8", &transformer_engine::pytorch::multi_tensor_adam_fp8_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer",
        py::call_guard<py::gil_scoped_release>());
  m.def("multi_tensor_adam_capturable",
        &transformer_engine::pytorch::multi_tensor_adam_capturable_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
        "support and LR scheduling",
        py::call_guard<py::gil_scoped_release>());
  m.def("multi_tensor_adam_capturable_master",
        &transformer_engine::pytorch::multi_tensor_adam_capturable_master_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
        "support, LR scheduling and FP32 master weights",
        py::call_guard<py::gil_scoped_release>());
  m.def("multi_tensor_sgd", &transformer_engine::pytorch::multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors",
        py::call_guard<py::gil_scoped_release>());
  m.def("multi_tensor_compute_scale_and_scale_inv",
        &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
        "Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
  m.def("multi_tensor_compute_scale_inv_e8m0",
        &transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda,
        "Fused compute E8M0 scale_inv from amax", py::call_guard<py::gil_scoped_release>());
  // Comm+GEMM Overlap
  m.def("bulk_overlap_ag_with_external_gemm",
        &transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm,
        "Bulk overlap All-Gather with a GEMM operation launched by another communicator",
        py::call_guard<py::gil_scoped_release>(), py::arg("allgather_communicator"),
        py::arg("send_stream"), py::arg("recv_stream"));
  // Data structures
  py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta")
      .def(py::init<>())
      .def_readwrite("scale", &transformer_engine::pytorch::FP8TensorMeta::scale)
      .def_readwrite("scale_inv", &transformer_engine::pytorch::FP8TensorMeta::scale_inv)
      .def_readwrite("amax_history", &transformer_engine::pytorch::FP8TensorMeta::amax_history);
  py::enum_<transformer_engine::pytorch::FP8FwdTensors>(m, "FP8FwdTensors")
      .value("GEMM1_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_INPUT)
      .value("GEMM1_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_WEIGHT)
      .value("GEMM1_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_OUTPUT)
      .value("GEMM2_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_INPUT)
      .value("GEMM2_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_WEIGHT)
      .value("GEMM2_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_OUTPUT)
      .value("GEMM3_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_INPUT)
      .value("GEMM3_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_WEIGHT)
      .value("GEMM3_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_OUTPUT);
  py::enum_<transformer_engine::pytorch::FP8BwdTensors>(m, "FP8BwdTensors")
      .value("GRAD_OUTPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT1)
      .value("GRAD_INPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT1)
      .value("GRAD_OUTPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT2)
      .value("GRAD_INPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT2)
      .value("GRAD_OUTPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT3)
      .value("GRAD_INPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT3);
  // Comm+GEMM overlap helper and communicator classes
  py::class_<CommOverlapHelper>(m, "CommOverlapHelper")
      .def(py::init<>(), py::call_guard<py::gil_scoped_release>())
      .def(py::init<c10d::ProcessGroup *, std::optional<c10d::ProcessGroup *>>(),
           py::call_guard<py::gil_scoped_release>(), py::arg("world_group"),
           py::arg("intra_node_group") = py::none());
  py::class_<CommOverlap, std::shared_ptr<CommOverlap>, transformer_engine::CommOverlapBase,
             transformer_engine::CommOverlapCore>(m, "CommOverlap")
      .def(py::init<const std::vector<size_t> &, at::ScalarType, CommOverlapHelper *, int, int, int,
                    int, int, int, int, bool, bool, bool>(),
           py::call_guard<py::gil_scoped_release>(), py::arg("buffer_shape"),
           py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"),
           py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS,
           py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0,
           py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
           py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
      // static_cast selects the (const at::Tensor &, bool) overload of copy_into_buffer.
      .def("copy_into_buffer",
           static_cast<void (CommOverlap::*)(const at::Tensor &, bool)>(
               &CommOverlap::copy_into_buffer),
           py::arg("input"), py::arg("local_chunk") = false)
      .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
           py::arg("shape") = std::nullopt)
      .def("get_communication_stream", &CommOverlap::get_communication_stream);
  py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
             transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
      m, "CommOverlapP2P")
      .def(py::init<const std::vector<size_t> &, at::ScalarType, CommOverlapHelper *, int,
                    transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool,
                    bool>(),
           py::call_guard<py::gil_scoped_release>(), py::arg("buffer_shape"),
           py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"),
           py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1,
           py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1,
           py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
           py::arg("use_ce") = true, py::arg("aggregate") = false)
      // static_cast selects the (const at::Tensor &, bool) overload of copy_into_buffer.
      .def("copy_into_buffer",
           static_cast<void (CommOverlapP2P::*)(const at::Tensor &, bool)>(
               &CommOverlapP2P::copy_into_buffer),
           py::arg("input"), py::arg("local_chunk") = false)
      .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
           py::arg("shape") = std::nullopt)
      .def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
}