@@ -104,55 +104,6 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
104104 threadsPerWarp, CTALayout);
105105 }
106106
107- static Type getNewType (Type type, Attribute encoding) {
108- RankedTensorType tensorType = cast<RankedTensorType>(type);
109- return RankedTensorType::get (tensorType.getShape (),
110- tensorType.getElementType (), encoding);
111- }
112-
113- void coalesceOp (Attribute encoding, Operation *op) {
114- OpBuilder builder (op);
115- // Convert operands
116- // For load/store with tensor pointers, we don't have to change the
117- // operands' type, we do this by changing the outputs' type of
118- // `make_tensor_ptr`
119- SmallVector<Value, 4 > newArgs;
120- for (auto operand : op->getOperands ()) {
121- auto tensorType = dyn_cast<RankedTensorType>(operand.getType ());
122- if (tensorType &&
123- !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding ())) {
124- Type newType = getNewType (tensorType, encoding);
125- newArgs.push_back (builder.create <triton::gpu::ConvertLayoutOp>(
126- op->getLoc (), newType, operand));
127- } else {
128- newArgs.push_back (operand);
129- }
130- }
131-
132- // Convert output types
133- SmallVector<Type, 4 > newTypes;
134- for (auto t : op->getResultTypes ()) {
135- bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
136- newTypes.push_back (isAsync ? t : getNewType (t, encoding));
137- }
138-
139- // Construct new op with the new encoding
140- Operation *newOp =
141- builder.create (op->getLoc (), op->getName ().getIdentifier (), newArgs,
142- newTypes, op->getAttrs ());
143-
144- // Cast the results back to the original layout
145- for (size_t i = 0 ; i < op->getNumResults (); i++) {
146- Value newResult = newOp->getResult (i);
147- if (newTypes[i] != op->getResultTypes ()[i]) {
148- newResult = builder.create <triton::gpu::ConvertLayoutOp>(
149- op->getLoc (), op->getResult (i).getType (), newResult);
150- }
151- op->getResult (i).replaceAllUsesWith (newResult);
152- }
153- op->erase ();
154- }
155-
156107 void runOnOperation () override {
157108 // Run axis info analysis
158109 ModuleOp moduleOp = getOperation ();
@@ -187,7 +138,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
187138 // 4. Convert the output of this new memory op back to L1
188139 // 5. Replace all the uses of the original memory op by the new one
189140 for (auto &kv : layoutMap) {
190- coalesceOp (kv.second , kv.first );
141+ convertOpEncoding (kv.second , kv.first );
191142 }
192143 }
193144};
0 commit comments