Skip to content

Commit ae35e4c

Browse files
bugfix graph capture
1 parent 51fab36 commit ae35e4c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

transformer_engine/common/recipe/current_scaling.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
110 110   num_blocks = std::min(num_blocks, max_blocks);
111 111
112 112   float* block_amax = nullptr;
113     - NVTE_CHECK_CUDA(cudaMalloc(&block_amax, num_blocks * sizeof(float)));
    113 + NVTE_CHECK_CUDA(cudaMallocAsync(&block_amax, num_blocks * sizeof(float), stream));
114 114
115 115   // Launch kernel
116 116   switch (align) {
@@ -141,6 +141,7 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
141 141
142 142   // Check results
143 143   NVTE_CHECK_CUDA(cudaGetLastError());
    144 + NVTE_CHECK_CUDA(cudaFreeAsync(block_amax, stream));
144 145   }
145 146
146 147   } // namespace

0 commit comments

Comments
 (0)