Skip to content

Commit 7cc8f48

Browse files
Address post review comment #4344 (#4348)
In the scope of `satisfies2DBlockReadAlignment`, it expects make tensor pointer op to be found. `satisfies2DBlockReadAlignment` is called at https://github.com/intel/intel-xpu-backend-for-triton/blob/main/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp#L74, and check of make tensor pointer is done earlier at https://github.com/intel/intel-xpu-backend-for-triton/blob/main/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp#L60. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent cb9a390 commit 7cc8f48

File tree

2 files changed

+33
-33
lines changed

2 files changed

+33
-33
lines changed

third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -127,39 +127,41 @@ struct TritonIntelTensorDescToBlockPointer
127127
}
128128

129129
void propagateToLoops(Operation *op) {
130-
if (auto loopOp = dyn_cast<LoopLikeOpInterface>(op)) {
131-
bool updated = false;
132-
for (auto [initArg, rgnInitArg, yieldVal, loopRes] :
133-
llvm::zip(loopOp.getInits(), loopOp.getRegionIterArgs(),
134-
loopOp.getYieldedValues(), loopOp->getResults())) {
135-
Type initArgType = initArg.getType();
136-
Type rgnInitArgType = rgnInitArg.getType();
137-
assert(rgnInitArgType == loopRes.getType() &&
138-
rgnInitArgType == yieldVal.getType() && "Type mismatch");
139-
if (rgnInitArgType != initArgType) {
140-
rgnInitArg.setType(initArgType);
141-
yieldVal.setType(initArgType);
142-
loopRes.setType(initArgType);
143-
updated = true;
144-
}
145-
}
146-
if (!updated)
147-
return;
148-
149-
// For while loops we also need to update the "after" region arguments.
150-
if (auto loopOp = dyn_cast<scf::WhileOp>(op)) {
151-
for (auto [initArg, rgnAfterArg] :
152-
llvm::zip(loopOp.getInits(), loopOp.getAfterArguments())) {
153-
Type initArgType = initArg.getType();
154-
if (rgnAfterArg.getType() != initArgType)
155-
rgnAfterArg.setType(initArgType);
156-
}
130+
auto loopOp = dyn_cast<LoopLikeOpInterface>(op);
131+
if (!loopOp)
132+
return;
133+
134+
bool updated = false;
135+
for (auto [initArg, rgnInitArg, yieldVal, loopRes] :
136+
llvm::zip(loopOp.getInits(), loopOp.getRegionIterArgs(),
137+
loopOp.getYieldedValues(), loopOp->getResults())) {
138+
Type initArgType = initArg.getType();
139+
Type rgnInitArgType = rgnInitArg.getType();
140+
assert(rgnInitArgType == loopRes.getType() &&
141+
rgnInitArgType == yieldVal.getType() && "Type mismatch");
142+
if (rgnInitArgType != initArgType) {
143+
rgnInitArg.setType(initArgType);
144+
yieldVal.setType(initArgType);
145+
loopRes.setType(initArgType);
146+
updated = true;
157147
}
148+
}
149+
if (!updated)
150+
return;
158151

159-
// Propagate the loop results to their users.
160-
for (Operation *user : loopOp->getUsers())
161-
propagateToLoops(user);
152+
// For while loops we also need to update the "after" region arguments.
153+
if (auto loopOp = dyn_cast<scf::WhileOp>(op)) {
154+
for (auto [initArg, rgnAfterArg] :
155+
llvm::zip(loopOp.getInits(), loopOp.getAfterArguments())) {
156+
Type initArgType = initArg.getType();
157+
if (rgnAfterArg.getType() != initArgType)
158+
rgnAfterArg.setType(initArgType);
159+
}
162160
}
161+
162+
// Propagate the loop results to their users.
163+
for (Operation *user : loopOp->getUsers())
164+
propagateToLoops(user);
163165
}
164166

165167
LogicalResult rewriteMakeTensorDescriptorOp(tt::MakeTensorDescOp op) {

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
295295
// operation.
296296
std::optional<tt::MakeTensorPtrOp> defOp =
297297
tt::intel::findDefiningMakeTensorPtrOp(ptr);
298-
if (!defOp)
299-
return false;
300-
298+
assert(defOp && "Expected a make tensor ptr op.");
301299
tt::MakeTensorPtrOp makeTensorPtrOp = *defOp;
302300
Operation::operand_range shape = makeTensorPtrOp.getShape();
303301
if (shape.size() == 1)

0 commit comments

Comments
 (0)