
Commit 3c9de07

Proper allocation of workspace
1 parent 6c3507d · commit 3c9de07

4 files changed: +46 −15 lines changed


transformer_engine/common/include/transformer_engine/recipe.h

Lines changed: 2 additions & 0 deletions
@@ -84,6 +84,8 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
  */
 void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
+void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream);
+
 /*! \brief Update an FP8 tensor's scale based on its amax.
  *
  * This is only supported for FP8 tensors with per-tensor scaling.
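
For context, a minimal caller sketch of the new entry point, assuming the TensorWrapper and makeTransformerEngineTensor helpers that appear in the PyTorch extension changes below (this sketch is illustrative, not part of the commit):

// Hypothetical caller sketch; te_input and te_output are TensorWrapper
// objects wrapping CUDA tensors, as in cast.cpp below.
auto ws = at::empty({static_cast<long>(max_blocks)},
                    at::dtype(at::kFloat).device(at::kCUDA));
TensorWrapper te_workspace = makeTransformerEngineTensor(ws);

// With a workspace: per-block amaxes go to ws, the final amax to te_output.
nvte_compute_amax_with_workspace(te_input.data(), te_output.data(),
                                 te_workspace.data(),
                                 at::cuda::getCurrentCUDAStream());

// Without a workspace: behaves exactly like the original nvte_compute_amax.
nvte_compute_amax_with_workspace(te_input.data(), te_output.data(),
                                 /*workspace=*/nullptr,
                                 at::cuda::getCurrentCUDAStream());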

transformer_engine/common/recipe/current_scaling.cu

Lines changed: 17 additions & 10 deletions
@@ -114,11 +114,7 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);
 
-  const bool UseBlockAmax = (block_amax != nullptr);
-
-  if (UseBlockAmax) {
-    NVTE_CHECK(block_capacity >= num_blocks);
-  }
+  const bool UseBlockAmax = (block_amax != nullptr) && (block_capacity >= num_blocks);
 
   // Launch kernel
   switch (align) {
@@ -167,6 +163,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
 }  // namespace transformer_engine
 
 void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
+  nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
+}
+
+void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) {
   NVTE_API_CALL(nvte_compute_amax);
   using namespace transformer_engine;
 
@@ -202,12 +202,19 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
              to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);
 
-  // Interpret output.data as workspace if present
-  float *block_amax = nullptr;
+  // Optional workspace
+  float* block_amax = nullptr;
   size_t block_capacity = 0;
-  if (output.data.dptr != nullptr) {
-    block_amax = reinterpret_cast<float*>(output.data.dptr);
-    block_capacity = output.data.numel();  // #floats in workspace
+
+  if (workspace_ != nullptr) {
+    auto &workspace = *reinterpret_cast<Tensor *>(workspace_);
+    NVTE_CHECK(workspace.data.dptr != nullptr,
+               "Workspace tensor for amax computation has no data");
+    NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
+               "Workspace tensor for amax computation must be FP32, got dtype=",
+               to_string(workspace.data.dtype));
+    block_amax = reinterpret_cast<float*>(workspace.data.dptr);
+    block_capacity = workspace.data.numel();
   }
 
   // Compute amax
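
To make the new fallback concrete: with the sizing used elsewhere in this commit (512 threads per block, a 65535-block hardware cap, nvec = 1 worst case), an undersized workspace no longer fails an NVTE_CHECK; UseBlockAmax simply evaluates to false and the kernel takes the non-workspace path. A self-contained worked example of the capacity arithmetic (values chosen for illustration):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t N = 1 << 20;            // e.g. a 1M-element input tensor
  const size_t threads = 512;          // amax kernel block size
  const size_t max_blocks_hw = 65535;  // grid-dimension cap

  // Worst case (nvec = 1): one block per `threads` elements, capped.
  const size_t num_blocks =
      std::min((N + threads - 1) / threads, max_blocks_hw);  // = 2048

  // A workspace of at least num_blocks floats keeps UseBlockAmax true.
  std::printf("need >= %zu floats (%zu bytes)\n",
              num_blocks, num_blocks * sizeof(float));  // 2048 floats, 8192 bytes
  return 0;
}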

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 18 additions & 1 deletion
@@ -52,8 +52,25 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
   if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
     // my_quantizer here has to be a Float8CurrentScalingQuantizer
     auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
+
+    // workspace for nvte_compute_amax_with_workspace
+    const auto N = static_cast<size_t>(input_tensor.numel());
+    constexpr size_t threads = 512;  // FIXME: should match amax_kernel_threads
+    constexpr size_t max_blocks_hw = 65535;
+
+    // Worst-case (nvec = 1) upper bound on the number of blocks.
+    size_t max_blocks = std::min(DIVUP(N, threads), max_blocks_hw);
+
+    // Allocate FP32 workspace for block-wise amax.
+    auto ws = at::empty({static_cast<long>(max_blocks)},
+                        tensor.options().dtype(at::kFloat));
+
+    TensorWrapper te_workspace = makeTransformerEngineTensor(ws);
+
     NVTE_SCOPED_GIL_RELEASE({
-      nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+      nvte_compute_amax_with_workspace(te_input.data(), te_output.data(),
+                                       te_workspace.data(),
+                                       at::cuda::getCurrentCUDAStream());
     });
     // check if we need to do amax reduction (depending on model parallel configs)
     if (my_quantizer_cs->with_amax_reduction) {
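
The FIXME above flags that the hard-coded threads = 512 can silently drift from amax_kernel_threads in current_scaling.cu. One hypothetical way to close that gap (names are assumptions, not part of this commit) is a shared sizing helper exported alongside the kernel constants:

// Hypothetical helper; amax_workspace_numel and the exported constants
// are assumptions, not part of this commit.
#include <algorithm>
#include <cstddef>

namespace transformer_engine {

constexpr size_t amax_kernel_threads = 512;       // must match current_scaling.cu
constexpr size_t amax_kernel_max_blocks = 65535;  // hardware grid cap

// Number of FP32 elements a caller should allocate for block-wise amax.
inline size_t amax_workspace_numel(size_t input_numel) {
  const size_t blocks =  // nvec = 1 worst case
      (input_numel + amax_kernel_threads - 1) / amax_kernel_threads;
  return std::min(blocks, amax_kernel_max_blocks);
}

}  // namespace transformer_engine

With such a helper, both cast.cpp and recipe.cpp could size ws from amax_workspace_numel(N) instead of duplicating the launch math.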

transformer_engine/pytorch/csrc/extensions/recipe.cpp

Lines changed: 9 additions & 4 deletions
@@ -30,19 +30,24 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
   size_t max_blocks = std::min(DIVUP(static_cast<size_t>(N), threads),
                                max_blocks_hw);
 
-  // Allocate workspace for the fake output tensor.
-  // This will be the block_amax buffer.
+  // Allocate workspace for the block_amax buffer.
   auto ws = at::empty({static_cast<long>(max_blocks)},
                       tensor.options().dtype(at::kFloat));
 
   std::vector<size_t> ws_shape{static_cast<size_t>(max_blocks)};
 
   TensorWrapper fake_te_output(
-      ws.data_ptr(), ws_shape,
+      nullptr, te_input.shape(),
       DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
       amax.data_ptr<float>());
 
-  nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
+  TensorWrapper te_workspace(
+      ws.data_ptr(), ws_shape,
+      DType::kFloat32,
+      nullptr
+  );
+
+  nvte_compute_amax_with_workspace(te_input.data(), fake_te_output.data(), te_workspace.data(), at::cuda::getCurrentCUDAStream());
 }
 
 void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer,
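
A hypothetical smoke test for the new wiring (not part of this commit): since fake_te_output now carries a null data pointer and only a valid amax pointer, compute_amax should fill amax without any output data buffer being allocated:

// Hypothetical smoke-test sketch; check_compute_amax is an assumption.
#include <ATen/ATen.h>
#include <torch/torch.h>

void check_compute_amax() {
  at::Tensor t = at::randn({1024, 1024},
                           at::dtype(at::kHalf).device(at::kCUDA));
  at::Tensor amax = at::zeros({1}, at::dtype(at::kFloat).device(at::kCUDA));

  // Allocates the FP32 block_amax workspace internally, as shown above.
  compute_amax(t, amax);

  // Expected: amax matches the dense reduction up to fp16 rounding.
  TORCH_CHECK(at::allclose(amax, t.abs().max().to(at::kFloat).reshape({1})));
}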
