Problem
The Metal ReLU example (examples/relu/relu_metal/relu.mm) creates compute command encoders directly on the command buffer:
```objc
id<MTLComputeCommandEncoder> computeEncoder =
    [commandBuffer computeCommandEncoder];
```
This bypasses PyTorch's MPS stream encoder lifecycle management (kernel coalescing). PyTorch's MPSStream keeps an internal _commandEncoder open and reuses it across operations. Metal allows only one encoder to be active on a command buffer at a time, so creating a second encoder while the stream's encoder is still open fails with:
```
[AGXG16XFamilyCommandBuffer tryCoalescingPreviousComputeCommandEncoderWithConfig:nextEncoderClass:]:
failed assertion 'A command encoder is already encoding to this command buffer'
```
This crashes the process: it is a Metal framework assertion, not a catchable exception.
When it triggers
Any time a Metal kernel is called twice in sequence without an intervening operation that flushes PyTorch's encoder state (e.g., .cpu(), torch.mps.synchronize()). This is common in real workloads.
Fix
Use stream->commandEncoder() from PyTorch's MPSStream API instead:
```objc
#include <ATen/mps/MPSStream.h>  // at::mps::MPSStream, at::mps::getCurrentMPSStream

at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
dispatch_sync(stream->queue(), ^{
  // Reuses the stream's active encoder, or creates one if none is open.
  id<MTLComputeCommandEncoder> enc = stream->commandEncoder();
  // ... encode work ...
  // Do NOT call [enc endEncoding]; the stream manages the encoder lifecycle.
});
// Ends the coalesced encoder (via endKernelCoalescing()) and commits the work.
stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
```
This properly integrates with PyTorch's encoder coalescing: it reuses the stream's active encoder (or creates one if needed), and synchronize() handles ending it via endKernelCoalescing().
Impact
Since the ReLU example serves as the reference implementation for all Metal kernels built with kernel-builder, any downstream kernel that follows this pattern will have the same bug. We discovered this while developing Metal kernels for rotary-embedding and fused-rms-norm: the crash is 100% reproducible when a kernel function is called twice in sequence.