Commit 74b4d86

chr1sj0nes authored and Google-ML-Automation committed
Add support for scratch buffers in jax_triton.
This is required to use device-side TMA descriptors.

PiperOrigin-RevId: 735985603
1 parent ff751ec commit 74b4d86

1 file changed, +1 -9 lines changed

jaxlib/gpu/triton_kernels.cc

Lines changed: 1 addition & 9 deletions
@@ -493,15 +493,7 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) {
           param.value)));
     }
   }
-  // Triton's kernel ABI expects an additional scratchpad global memory.
-  // For now it is only used for on-device creation of TMA descriptors, which
-  // we do not use yet, so we are just replacing this argument with a null
-  // pointer.
-  // TODO: b/381242007 - Allocate a proper buffer if we want to use
-  // device-side TMA APIs.
-  void* scratch_ptr = nullptr;  // Alive until kernel_.Launch returns.
-  params.push_back(&scratch_ptr);
-
+  params.push_back(buffers++);  // Scratch buffer.
   return kernel_.Launch(stream, grid_, params.data());
 }
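The change above can be illustrated with a small host-only sketch. It is not code from jaxlib: `BuildParams` and `FakeLaunch` are hypothetical names, the loop over regular buffer arguments is a simplification of whatever precedes the changed line, and the sketch assumes the usual driver-style convention in which each entry of the parameter array points at the storage holding one kernel argument. Under those assumptions, pushing `buffers++` makes the kernel's trailing scratch argument read its device pointer from the next slot of the caller-provided buffer list, where the removed code pointed it at a local null pointer.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a kernel launch. It reads every argument through
// the params array (params[i] points at the storage holding argument i) and
// returns the pointer values the "kernel" would receive.
std::vector<void*> FakeLaunch(void** params, int num_params) {
  std::vector<void*> seen;
  for (int i = 0; i < num_params; ++i) {
    seen.push_back(*static_cast<void**>(params[i]));
  }
  return seen;
}

// Simplified sketch of the marshaling step (an assumption, not the jaxlib
// code): each regular buffer argument consumes one slot of `buffers`, and the
// scratch argument then consumes the slot that follows.
std::vector<void*> BuildParams(void** buffers, int num_buffer_args) {
  std::vector<void*> params;
  for (int i = 0; i < num_buffer_args; ++i) {
    params.push_back(buffers++);  // Address of the slot holding the pointer.
  }
  params.push_back(buffers++);  // Scratch buffer, as in the new line above.
  return params;
}
```

If this reading is right, the caller has to supply one extra entry in `buffers` after the regular arguments, since the launch path now consumes it unconditionally.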
