Skip to content

Commit c4d19ca

Browse files
gflegar authored and Google-ML-Automation committed
PiperOrigin-RevId: 702397897
1 parent dfa0dd7 commit c4d19ca

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

jaxlib/gpu/triton_kernels.cc

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -466,7 +466,8 @@ KernelCall::KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1,
466466

467467
absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) {
468468
std::vector<void*> params;
469-
params.reserve(parameters_.size());
469+
// We need an additional parameter for the scratchpad buffer.
470+
params.reserve(parameters_.size() + 1);
470471
for (size_t i = 0; i < parameters_.size(); ++i) {
471472
const Parameter& param = parameters_[i];
472473
if (std::holds_alternative<Parameter::Array>(param.value)) {
@@ -492,6 +493,14 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) {
492493
param.value)));
493494
}
494495
}
496+
// Triton's kernel ABI expects an additional scratchpad global memory.
497+
// For now it is only used for on-device creation of TMA descriptors, which
498+
// we do not use yet, so we are just replacing this argument with a null
499+
// pointer.
500+
// TODO: b/381242007 - Allocate a proper buffer if we want to use
501+
// device-side TMA APIs.
502+
void* scratch_ptr = nullptr; // Alive until kernel_.Launch returns.
503+
params.push_back(&scratch_ptr);
495504

496505
return kernel_.Launch(stream, grid_, params.data());
497506
}

0 commit comments

Comments
 (0)