Skip to content

Commit a89c44f

Browse files
committed
Merge remote-tracking branch 'origin/nvdlff-inspect-support' into pgadzinski/debugtools-cppqtensor
2 parents 32e2e05 + 5904a80 commit a89c44f

File tree

24 files changed

+136
-129
lines changed

24 files changed

+136
-129
lines changed

qa/L0_pytorch_unittest/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
1111
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
1212
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
1313
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
14-
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
14+
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
1515
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
1616
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
1717
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py

qa/L1_pytorch_distributed_unittest/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
1111
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
1212
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
1313
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
14-
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
14+
# pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential
1515
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py

transformer_engine/common/common.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,8 @@ struct is_fp8<fp8e5m2> : std::true_type {};
423423
size_t typeToSize(const DType type);
424424

425425
void CheckNoopTensor(const Tensor &t, const std::string &name);
426-
void CheckInputTensor(const Tensor &t, const std::string &name,
427-
bool check_scale_inv_alignment = false);
428-
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false,
429-
bool check_scale_inv_alignment = false);
426+
void CheckInputTensor(const Tensor &t, const std::string &name);
427+
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);
430428

431429
bool is_fp8_dtype(const DType t);
432430

transformer_engine/common/swizzle/swizzle.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
210210
return;
211211
}
212212

213-
CheckInputTensor(*input, "scaling_factor_input", true);
214-
CheckInputTensor(*output, "scaling_factor_output", true);
213+
CheckInputTensor(*input, "scaling_factor_input");
214+
CheckInputTensor(*output, "scaling_factor_output");
215215

216216
auto& scaling_mode = input->scaling_mode;
217217

transformer_engine/common/transformer_engine.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) {
6565
}
6666
}
6767

68-
void CheckScaleTensorShape(const Tensor &t, bool check_scale_inv_alignment) {
68+
void CheckScaleTensorShape(const Tensor &t) {
6969
NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
7070
if (is_tensor_scaling(t.scaling_mode)) {
7171
// per-tensor scaling
@@ -80,7 +80,6 @@ void CheckScaleTensorShape(const Tensor &t, bool check_scale_inv_alignment) {
8080
}
8181
} else {
8282
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
83-
if (!check_scale_inv_alignment) return;
8483
// Need (4, 128) alignment even for e8 scaling factor
8584
auto block_alignment = std::vector<size_t>{128ul / typeToSize(t.scale_inv.dtype),
8685
4ul / typeToSize(t.scale_inv.dtype)};
@@ -111,7 +110,7 @@ void CheckScaleTensorShape(const Tensor &t, bool check_scale_inv_alignment) {
111110
}
112111
}
113112

114-
void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale_inv_alignment) {
113+
void CheckInputTensor(const Tensor &t, const std::string &name) {
115114
const DType type = t.dtype();
116115
if (is_fp8_dtype(type)) {
117116
// FP8 input needs to have scale_inv
@@ -143,11 +142,10 @@ void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale
143142
}
144143
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");
145144

146-
CheckScaleTensorShape(t, check_scale_inv_alignment);
145+
CheckScaleTensorShape(t);
147146
}
148147

149-
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty,
150-
bool check_scale_inv_alignment) {
148+
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
151149
const DType type = t.dtype();
152150
if (is_fp8_dtype(type)) {
153151
// FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
@@ -189,7 +187,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
189187
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!");
190188
}
191189

192-
CheckScaleTensorShape(t, check_scale_inv_alignment);
190+
CheckScaleTensorShape(t);
193191
}
194192

195193
} // namespace transformer_engine

transformer_engine/pytorch/cpp_extensions/gemm.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,13 @@
1111
import transformer_engine_torch as tex
1212
from ..constants import TE_DType
1313
from ..utils import assert_dim_for_fp8_exec, get_sm_count
14-
from ..tensor.quantized_tensor import QuantizedTensor
15-
from ..tensor.float8_tensor import Float8Tensor, Float8TensorBase
16-
from ..tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase
1714

1815
from ..tensor.quantized_tensor import Quantizer
16+
from ..tensor.float8_tensor import Float8Tensor
17+
from ..tensor.mxfp8_tensor import MXFP8Tensor
1918
from ..tensor._internal.float8_tensor_base import Float8TensorBase
2019
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
2120

22-
23-
24-
from ..tensor.quantized_tensor import (
25-
QuantizedTensor,
26-
Quantizer,
27-
prepare_for_saving,
28-
restore_from_saved,
29-
)
30-
31-
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
32-
3321
__all__ = [
3422
"general_gemm",
3523
"general_grouped_gemm",

transformer_engine/pytorch/csrc/common.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,4 +223,9 @@ std::vector<size_t> convertShape(const NVTEShape& shape) {
223223
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
224224
}
225225

226+
int roundup(const int value, const int multiple) {
227+
assert(multiple > 0);
228+
return ((value + multiple - 1) / multiple) * multiple;
229+
}
230+
226231
} // namespace transformer_engine::pytorch

transformer_engine/pytorch/csrc/common.h

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,6 @@ class FP8TensorMeta {
5959
at::Tensor amax_history;
6060
};
6161

62-
// FP8TensorMeta for block scaling, this structure allows
63-
// indexing into it the same way (i.e. using FP8FwdTensors
64-
// and FP8BwdTensors) for both hopper and blackwell recipes.
65-
// TODO(ksivaman): check perf with this design; should be ok
66-
// since there are no amax reductions, or bulk amax/scale
67-
// updates for block scaling.
68-
class MXFP8TensorMeta {
69-
public:
70-
std::vector<at::Tensor> scale;
71-
std::vector<at::Tensor> scale_inv;
72-
std::vector<at::Tensor> amax_history;
73-
};
74-
7562
// Used as named indices on the `scale`, `scale_inv`,
7663
// and `amax` tensors in the `FP8TensorMeta` class.
7764
enum FP8FwdTensors {
@@ -265,6 +252,8 @@ void* getDataPtr(at::Tensor tensor, int offset = 0);
265252

266253
std::vector<size_t> convertShape(const NVTEShape& shape);
267254

255+
int roundup(const int value, const int multiple);
256+
268257
} // namespace transformer_engine::pytorch
269258

270259
namespace std {

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,6 @@ at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);
361361

362362
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv);
363363

364-
at::Tensor pad_scale_inv(at::Tensor scale_inv, bool rowwise);
365-
366364
/***************************************************************************************************
367365
* Comm+GEMM Overlap Wrappers
368366
**************************************************************************************************/

transformer_engine/pytorch/csrc/extensions/quantizer.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,28 +174,34 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
174174
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
175175
auto last_dim = torch_shape.back();
176176

177+
NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
178+
"MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
179+
" (got shape=", torch_shape, ")");
180+
177181
at::Tensor data;
178182
if (rowwise_usage) {
179183
if (rowwise_data.has_value()) {
180184
data = std::move(*rowwise_data);
181185
} else {
182186
data = at::empty(torch_shape, opts);
183187
}
184-
rowwise_scale_inv = at::empty({numel / last_dim, last_dim / MXFP8_BLOCK_SIZE}, opts);
188+
auto sinv0 = roundup(numel / last_dim, 128);
189+
auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
190+
rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
185191
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
186-
tensor.set_rowwise_scale_inv(
187-
rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
188-
std::vector<size_t>{numel / last_dim, last_dim / MXFP8_BLOCK_SIZE});
189-
} else {
192+
tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
193+
std::vector<size_t>{sinv0, sinv1});
190194
}
195+
191196
if (columnwise_usage) {
197+
auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
198+
auto sinv1 = roundup(last_dim, 128);
192199
columnwise_data = at::empty(torch_shape, opts);
193-
columnwise_scale_inv = at::empty({numel / (last_dim * MXFP8_BLOCK_SIZE), last_dim}, opts);
200+
columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
194201

195202
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
196-
tensor.set_columnwise_scale_inv(
197-
columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
198-
std::vector<size_t>{numel / (last_dim * MXFP8_BLOCK_SIZE), last_dim});
203+
tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
204+
std::vector<size_t>{sinv0, sinv1});
199205
}
200206
this->set_quantization_params(&tensor);
201207

0 commit comments

Comments
 (0)