@@ -736,15 +736,15 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
736736 LinearLayout ret =
737737 LinearLayout (std::move (bases), llvm::to_vector (sliceLL.getOutDimNames ()));
738738
739- // Match a hack in the legacy code that ensures that the number of registers
740- // matches getTotalElemsPerThread. Yup: We just removed all the zeros, now
741- // we're (maybe) adding some back. :)
742- //
743- // TODO(jlebar): Once getTotalElemsPerThread uses LLs instead of the existing
744- // legacy code, I think we can remove this .
745- int expectedNumRegisters =
746- triton::gpu::getTotalElemsPerThread ( RankedTensorType::get (
747- shape, IntegerType::get (ctx, 32 ) /* dummy type */ , * this )) ;
739+ // The semantic of the slice layout:
740+ // The threads of the parent layout which are distributed on the
741+ // sliced dim are squeezed to hold the same value of tensor redundantly.
742+ // Only the number of values of sizePerThreads[dim] of the parent are reduced
743+ // to the one. We need to fix up the number of registers in case we just
744+ // removed all zeros bases aggressively .
745+ auto sizePerThreads = triton::gpu::getSizePerThread ( getParent ());
746+ unsigned expectedNumRegisters =
747+ parentLL-> getInDimSize ( S ( " register " )) / sizePerThreads[ getDim ()] ;
748748 if (ret.getInDimSize (S (" register" )) != expectedNumRegisters) {
749749 int extraZeros = expectedNumRegisters / ret.getInDimSize (S (" register" ));
750750 // Our use of "dim0" here is arbitrary; because we're adding zeros, any
0 commit comments