Commit c0d8e73

outline workspace allocation
1 parent 77a68a7 commit c0d8e73

2 files changed: +60 -15 lines changed

transformer_engine/common/recipe/current_scaling.cu

Lines changed: 42 additions & 14 deletions
@@ -47,9 +47,9 @@ __global__ void amax_final_reduce(const float* __restrict__ block_amax,
   }
 }
 
-template <int nvec, bool aligned, typename InputType>
+template <int nvec, bool aligned, typename InputType, bool UseBlockAmax>
 __launch_bounds__(amax_kernel_threads) __global__
-    void amax_kernel(const InputType *input, float* __restrict__ block_amax, const size_t N,
+    void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N,
                      const size_t num_aligned_elements) {
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
@@ -85,12 +85,17 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
-    block_amax[blockIdx.x] = max;
+    if constexpr (UseBlockAmax) {
+      block_amax[blockIdx.x] = max;
+    } else {
+      atomicMaxFloat(amax, max);
+    }
   }
 }
 
 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
+                        size_t block_capacity, cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
   (void)cudaMemsetAsync(amax, 0, sizeof(float), stream);
 
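The fallback branch above writes through atomicMaxFloat, a helper defined elsewhere in this file. For orientation only, a common pattern behind such a helper looks like the sketch below; Transformer Engine's actual implementation may differ. It is valid here because amax holds non-negative absolute values and is zeroed before the launch, so an integer atomicMax over the reinterpreted bits yields the float maximum.

// Hypothetical sketch, not the library's definition: correct only for non-negative floats,
// whose IEEE-754 bit patterns order the same as signed integers.
__device__ inline void atomic_max_nonneg_float(float *addr, float value) {
  atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
}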
@@ -109,28 +114,43 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);
 
-  float* block_amax = nullptr;
-  NVTE_CHECK_CUDA(cudaMallocAsync(&block_amax, num_blocks * sizeof(float), stream));
+  const bool UseBlockAmax = (block_amax != nullptr);
+
+  if (UseBlockAmax) {
+    NVTE_CHECK(block_capacity >= num_blocks);
+  }
 
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
-      amax_kernel<nvec, true, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, block_amax, N, num_aligned_elements);
+      // FIXME: this code is clumsy. Perhaps don't use the UseBlockAmax extra template argument
+      if (UseBlockAmax)
+        amax_kernel<nvec, true, InputType, true>
+            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+      else
+        amax_kernel<nvec, true, InputType, false>
+            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
       break;
     case Alignment::SAME_UNALIGNED:
-      amax_kernel<nvec, false, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, block_amax, N, num_aligned_elements);
+      if (UseBlockAmax)
+        amax_kernel<nvec, false, InputType, true>
+            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+      else
+        amax_kernel<nvec, false, InputType, false>
+            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
       break;
     case Alignment::DIFFERENT: {
       // This case is a logic error, since there is only one pointer (input)
       // in the alignment check. Still safe to process without vectorization.
-      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, block_amax, N, N);
+      if (UseBlockAmax)
+        amax_kernel<1, true, InputType, true><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
+      else
+        amax_kernel<1, true, InputType, false><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
       break;
     }
   }
 
-  {
+  if (UseBlockAmax) {
     constexpr int FINAL_REDUCE_THREADS = 256;
     dim3 fr_block(FINAL_REDUCE_THREADS);
     dim3 fr_grid(1);
@@ -141,7 +161,6 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
 
   // Check results
   NVTE_CHECK_CUDA(cudaGetLastError());
-  NVTE_CHECK_CUDA(cudaFreeAsync(block_amax, stream));
 }
 
 }  // namespace
@@ -183,11 +202,20 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
              to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);
 
+  // Interpret output.data as workspace if present
+  float *block_amax = nullptr;
+  size_t block_capacity = 0;
+  if (output.data.dptr != nullptr) {
+    block_amax = reinterpret_cast<float*>(output.data.dptr);
+    block_capacity = output.data.numel();  // #floats in workspace
+  }
+
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
       launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
-                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
+                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(), block_amax,
+                               block_capacity,
                                stream););  // NOLINT(*)
 }
 
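The FIXME in the launch code above notes that threading the extra UseBlockAmax template argument through every Alignment case is clumsy. One possible cleanup, sketched here purely as an illustration (launch_amax_variant is a hypothetical name, not part of this commit), is to fold the runtime flag into the compile-time parameter in one helper so each case is written once.

// Illustrative sketch only: dispatch the runtime flag to the compile-time
// UseBlockAmax parameter in a single place.
template <int nvec, bool aligned, typename InputType>
void launch_amax_variant(bool use_block_amax, const InputType *input, float *amax,
                         float *block_amax, size_t N, size_t num_aligned_elements,
                         size_t num_blocks, size_t threads, cudaStream_t stream) {
  if (use_block_amax) {
    amax_kernel<nvec, aligned, InputType, true>
        <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
  } else {
    amax_kernel<nvec, aligned, InputType, false>
        <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
  }
}

Each case in the switch would then reduce to a single call such as launch_amax_variant<nvec, true>(UseBlockAmax, input, amax, block_amax, N, num_aligned_elements, num_blocks, threads, stream).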
transformer_engine/pytorch/csrc/extensions/recipe.cpp

Lines changed: 18 additions & 1 deletion
@@ -20,8 +20,25 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
 
   TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
   TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
+
+  // Compute an upper bound on the number of blocks for this input.
+  const auto N = input_tensor.numel();
+  constexpr size_t threads = 512;  // FIXME: should grab amax_kernel_threads here
+  constexpr size_t max_blocks_hw = 65535;
+
+  // Assume worst-case vectorization (nvec = 1) as an upper bound.
+  size_t max_blocks = std::min(DIVUP(static_cast<size_t>(N), threads),
+                               max_blocks_hw);
+
+  // Allocate workspace for the fake output tensor.
+  // This will be the block_amax buffer.
+  auto ws = at::empty({static_cast<long>(max_blocks)},
+                      tensor.options().dtype(at::kFloat));
+
+  std::vector<size_t> ws_shape{static_cast<size_t>(max_blocks)};
+
   TensorWrapper fake_te_output(
-      nullptr, te_input.shape(),
+      ws.data_ptr(), ws_shape,
       DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
       amax.data_ptr<float>());
 
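The workspace sizing above relies on the fact that assuming no vectorization (nvec = 1) over-estimates the grid the kernel will actually launch, since the vectorized kernel works over roughly N / nvec elements before computing its block count. A small standalone check of that arithmetic, with illustrative values rather than anything taken from the commit:

// Sanity check of the nvec = 1 upper bound used when sizing the block_amax workspace.
#include <algorithm>
#include <cstddef>
#include <cstdio>

// ceil(a / b), mirroring what the DIVUP macro is expected to compute.
static size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  constexpr size_t threads = 512;         // matches the hard-coded value above
  constexpr size_t max_blocks_hw = 65535;
  const size_t N = (1u << 20) + 123;      // arbitrary example element count

  const size_t bound = std::min(divup(N, threads), max_blocks_hw);  // nvec = 1 bound
  for (size_t nvec = 1; nvec <= 32; nvec *= 2) {
    const size_t blocks = std::min(divup(divup(N, nvec), threads), max_blocks_hw);
    std::printf("nvec=%2zu blocks=%6zu bound=%6zu within=%d\n",
                nvec, blocks, bound, blocks <= bound);
  }
  return 0;
}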