@@ -52,7 +52,8 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
     });
     addTargetMaterialization([&](OpBuilder &builder,
                                  UnrankedMemRefType resultType,
-                                 ValueRange inputs, Location loc) -> Value {
+                                 ValueRange inputs,
+                                 Location loc) -> Value {
       return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
           .getResult(0);
     });
@@ -157,7 +158,7 @@ struct ScalarStoreConverter : public OpConversionPattern<tts::ScatterOp> {
   }
 };
 
-// Lowering an unstructured load op (gather) into a linalg.generic op.
+// Lowering an unstructured load op (gather) into a linalg.generic op
 struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   using OpConversionPattern<tts::GatherOp>::OpConversionPattern;
 
@@ -170,109 +171,118 @@ struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   LogicalResult
   matchAndRewrite(tts::GatherOp gatherOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
     auto loc = gatherOp->getLoc();
 
     auto ptr = adaptor.getPtr();
     auto offsetTensor = adaptor.getOffset();
     auto offsetType = dyn_cast<ShapedType>(offsetTensor.getType());
 
-    // This must be a scalar load, skip processing.
+    // This must be a scalar load, skip processing
     if (!offsetType) {
       return failure();
     }
 
-    auto resultType =
+    auto loadResultType =
         dyn_cast<RankedTensorType>(gatherOp.getResult().getType());
 
     // Treat the base pointer (memref) as 1D because the offsets are all
     // relative to a single base pointer (already collapsed).
-    auto baseMemref = rewriter
-                          .create<memref::CastOp>(
-                              loc,
-                              MemRefType::get({ShapedType::kDynamic},
-                                              resultType.getElementType()),
-                              ptr)
-                          .getResult();
+    auto baseMemref = rewriter.create<memref::CastOp>(
+        loc,
+        MemRefType::get({ShapedType::kDynamic},
+                        loadResultType.getElementType()),
+        ptr);
 
     auto baseTensor =
         rewriter
             .create<bufferization::ToTensorOp>(
                 loc,
                 RankedTensorType::get(
                     SmallVector<int64_t>(1, ShapedType::kDynamic),
-                    resultType.getElementType()),
+                    loadResultType.getElementType()),
                 baseMemref, true /*restrict*/, false /*writable*/)
             .getResult();
 
     // The linalg.generic op should have the following inputs:
-    // - the offset tensor.
-    // - an optional mask tensor if the gather op contains mask.
+    // - the offset tensor
+    // - an optional mask tensor if the load op contains mask
     SmallVector<Value> inputs{offsetTensor};
 
     if (gatherOp.getMask()) {
       inputs.push_back(gatherOp.getMask());
     }
 
-    auto emptyTensor = rewriter
-                           .create<tensor::EmptyOp>(loc, resultType.getShape(),
-                                                    resultType.getElementType())
-                           .getResult();
+    auto emptyTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, loadResultType.getShape(),
+                                     loadResultType.getElementType())
+            .getResult();
 
-    // Affine maps for the inputs and one additional output.
+    // Affine maps for the inputs and output.
+    // If no mask is used, 2 affine maps are generated: one for the input
+    // offset tensor, the other for the output tensor.
+    // If a mask is used, the first 2 maps are for the offset and mask tensors
+    // while the last map is for the output tensor.
     SmallVector<AffineMap> affineMaps(
-        inputs.size() + 1,
-        rewriter.getMultiDimIdentityMap(resultType.getRank()));
-
-    // All iterator types are parallel.
-    SmallVector<utils::IteratorType> iteratorTypes(
-        resultType.getRank(), utils::IteratorType::parallel);
+        gatherOp.getMask() ? 3 : 2,
+        rewriter.getMultiDimIdentityMap(loadResultType.getRank()));
 
     auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, TypeRange{resultType}, inputs, ValueRange{emptyTensor}, affineMaps,
-        iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
-          auto getValueAtIndex = [baseTensor](OpBuilder &b, Location loc,
-                                              Value index) -> Value {
+        loc, SmallVector<Type>({loadResultType}), inputs,
+        ValueRange{emptyTensor}, affineMaps,
+        SmallVector<utils::IteratorType>(loadResultType.getRank(),
+                                         utils::IteratorType::parallel),
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto getValueAtIndex = [baseTensor](Value indexValue, Location loc,
+                                              OpBuilder &b) -> Value {
             Value index0 =
-                b.create<arith::IndexCastOp>(loc, b.getIndexType(), index);
+                b.create<arith::IndexCastOp>(loc, b.getIndexType(), indexValue);
 
             return b.create<tensor::ExtractOp>(loc, baseTensor,
                                                ValueRange{index0});
           };
 
-          auto offset = args[0];
-
           if (!gatherOp.getMask()) {
            // If there is no mask, simply extract the current element from the
            // base tensor and use it as the yield value.
-            auto loadValue = getValueAtIndex(b, loc, offset);
-            b.create<linalg::YieldOp>(loc, loadValue);
+            auto loadValue = getValueAtIndex(args[0], loc, rewriter);
+            rewriter.create<linalg::YieldOp>(loc, loadValue);
          } else {
            // If the mask value is truthy, the current element is loaded from
            // the base tensor using its offset. Otherwise, if `other` is
            // present, yield `other`. If `other` is not present, a default
            // value of 0 is used.
            auto mask = args[1];
-            auto ifOp = b.create<scf::IfOp>(
+            auto ifOp = rewriter.create<scf::IfOp>(
                loc, mask,
                [&](OpBuilder &b, Location loc) {
-                  // Truthy case, load from the index.
-                  auto value = getValueAtIndex(b, loc, offset);
-                  b.create<scf::YieldOp>(loc, value);
+                  // Truthy case, load from the index
+                  auto loadValue = getValueAtIndex(args[0], loc, b);
+                  b.create<scf::YieldOp>(loc, loadValue);
                },
                [&](OpBuilder &b, Location loc) {
-                  // Falsy case, yield `other` or 0 as the default value.
+                  // Falsy case, yield `other` or 0 as the default value
                  if (gatherOp.getOther()) {
                    b.create<scf::YieldOp>(loc, gatherOp.getOther());
                  } else {
-                    auto elemType = resultType.getElementType();
-                    auto zeroAttr = b.getZeroAttr(elemType);
-                    assert(zeroAttr && "unexpected element type");
-                    Value extract = b.create<arith::ConstantOp>(loc, zeroAttr);
+                    auto elemType = baseTensor.getType().getElementType();
+                    Value extract;
+                    if (isa<IntegerType>(elemType)) {
+                      extract = rewriter.create<arith::ConstantOp>(
+                          loc, b.getIntegerAttr(elemType, 0));
+                    } else if (isa<FloatType>(elemType)) {
+                      extract = rewriter.create<arith::ConstantOp>(
+                          loc, b.getFloatAttr(elemType, 0));
+                    } else {
+                      elemType.dump();
+                      llvm_unreachable("unexpected type");
+                    }
                    b.create<scf::YieldOp>(loc, extract);
                  }
                });
 
-            b.create<linalg::YieldOp>(loc, ifOp->getResult(0));
+            rewriter.create<linalg::YieldOp>(loc, ifOp->getResult(0));
          }
        });
 
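For intuition, here is a rough sketch of the IR this pattern emits for a masked 1-D gather of f32 values. The shapes, value names, and the exact bufferization.to_tensor spelling are illustrative only, not taken from this patch:

    #map = affine_map<(d0) -> (d0)>
    %base = memref.cast %ptr : memref<1024xf32> to memref<?xf32>
    %base_t = bufferization.to_tensor %base restrict : memref<?xf32>
    %empty = tensor.empty() : tensor<64xf32>
    %res = linalg.generic
        {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
        ins(%offsets, %mask : tensor<64xi32>, tensor<64xi1>)
        outs(%empty : tensor<64xf32>) {
    ^bb0(%off: i32, %m: i1, %out: f32):
      %v = scf.if %m -> (f32) {
        %idx = arith.index_cast %off : i32 to index
        %elem = tensor.extract %base_t[%idx] : tensor<?xf32>
        scf.yield %elem : f32
      } else {
        %zero = arith.constant 0.0 : f32  // no `other` operand: default to zero
        scf.yield %zero : f32
      }
      linalg.yield %v : f32
    } -> tensor<64xf32>
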
@@ -282,7 +292,7 @@ struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   }
 };
 
-// Lowering an unstructured store op (scatter) into a linalg.generic op.
+// Lowering an unstructured store op (scatter) into an affine loop nest
 struct ScatterConverter : public OpConversionPattern<tts::ScatterOp> {
   using OpConversionPattern<tts::ScatterOp>::OpConversionPattern;
 
@@ -299,81 +309,55 @@ struct ScatterConverter : public OpConversionPattern<tts::ScatterOp> {
 
     auto ptr = adaptor.getPtr();
     auto offsetTensor = adaptor.getOffset();
-    auto valueTensor = adaptor.getValue();
     auto offsetType = dyn_cast<ShapedType>(offsetTensor.getType());
 
-    // This must be a scalar store, skip processing.
+    // This must be a scalar store, skip processing
     if (!offsetType) {
       return failure();
     }
 
-    auto valueType = dyn_cast<RankedTensorType>(scatterOp.getValue().getType());
-
-    // Treat the base pointer (memref) as 1D because the offsets are all
-    // relative to a single base pointer (already collapsed).
-    auto baseMemref =
-        rewriter
-            .create<memref::CastOp>(loc,
-                                    MemRefType::get({ShapedType::kDynamic},
-                                                    valueType.getElementType()),
-                                    ptr)
-            .getResult();
-
-    // The linalg.generic op should have the following inputs:
-    // - the offset tensor.
-    // - the value tensor.
-    // - an optional mask tensor if the scatter op contains mask.
-    SmallVector<Value> inputs{offsetTensor, valueTensor};
+    auto resultType =
+        dyn_cast<RankedTensorType>(scatterOp.getValue().getType());
 
-    if (scatterOp.getMask()) {
-      inputs.push_back(scatterOp.getMask());
+    auto storeMemref = rewriter.create<memref::CastOp>(
+        loc,
+        MemRefType::get({ShapedType::kDynamic}, resultType.getElementType()),
+        ptr);
+
+    auto ip = rewriter.saveInsertionPoint();
+
+    SmallVector<Value> ivs;
+    for (auto dim : resultType.getShape()) {
+      auto forOp = rewriter.create<affine::AffineForOp>(loc, 0, dim);
+      ivs.push_back(forOp.getInductionVar());
+      rewriter.setInsertionPointToStart(forOp.getBody());
     }
 
-    // Affine maps for the inputs.
-    SmallVector<AffineMap> affineMaps(
-        inputs.size(), rewriter.getMultiDimIdentityMap(valueType.getRank()));
-
-    // All iterator types are parallel.
-    SmallVector<utils::IteratorType> iteratorTypes(
-        valueType.getRank(), utils::IteratorType::parallel);
-
-    rewriter.setInsertionPoint(scatterOp);
-
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, TypeRange{}, inputs, ValueRange{}, affineMaps, iteratorTypes,
-        [&](OpBuilder &b, Location loc, ValueRange args) {
-          auto storeValueAtIndex = [baseMemref](OpBuilder &b, Location loc,
-                                                Value index, Value value) {
-            Value index0 =
-                b.create<arith::IndexCastOp>(loc, b.getIndexType(), index);
-
-            b.create<memref::StoreOp>(loc, value, baseMemref,
-                                      ValueRange{index0});
-          };
+    if (scatterOp.getMask()) {
+      // Mask case, only store the value if the mask value at `ivs` is truthy
+      auto maskValue =
+          rewriter.create<tensor::ExtractOp>(loc, scatterOp.getMask(), ivs);
 
-          auto offset = args[0];
-          auto value = args[1];
+      auto ifOp = rewriter.create<scf::IfOp>(loc, maskValue,
+                                             false /*withElseRegion*/);
 
-          if (!scatterOp.getMask()) {
-            // If there is no mask, simply insert the current value to the
-            // base memref using its offset.
-            storeValueAtIndex(b, loc, offset, value);
-          } else {
-            // If the mask value is truthy, insert the current value to the
-            // the base memref using its offset. Otherwise, noop.
-            auto mask = args[2];
-            auto ifOp =
-                b.create<scf::IfOp>(loc, mask, [&](OpBuilder &b, Location loc) {
-                  storeValueAtIndex(b, loc, offset, value);
-                  b.create<scf::YieldOp>(loc);
-                });
-          }
+      rewriter.setInsertionPointToStart(
+          &ifOp.getThenRegion().getBlocks().front());
+    }
 
-          b.create<linalg::YieldOp>(loc);
-        });
+    // Generate ops to store the value at each index. Note that with masking,
+    // these ops are created in the `if` block generated above.
+    auto offsetValue =
+        rewriter.create<tensor::ExtractOp>(loc, offsetTensor, ivs);
+    auto storeValue =
+        rewriter.create<tensor::ExtractOp>(loc, scatterOp.getValue(), ivs);
+    Value storeIndex = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), offsetValue);
+    rewriter.create<memref::StoreOp>(loc, storeValue, storeMemref, storeIndex);
 
+    // Finalize
     rewriter.eraseOp(scatterOp);
-
+    rewriter.restoreInsertionPoint(ip);
     return success();
   }
 };
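For the same illustrative shapes, the new scatter lowering produces roughly this loop nest in the masked case (names again hypothetical, not from the patch):

    %dst = memref.cast %ptr : memref<1024xf32> to memref<?xf32>
    affine.for %i = 0 to 64 {
      %m = tensor.extract %mask[%i] : tensor<64xi1>
      scf.if %m {
        %off = tensor.extract %offsets[%i] : tensor<64xi32>
        %val = tensor.extract %values[%i] : tensor<64xf32>
        %idx = arith.index_cast %off : i32 to index
        memref.store %val, %dst[%idx] : memref<?xf32>
      }
    }

Switching from linalg.generic to an explicit affine.for nest avoids building a linalg op with no results: the stores become plain side effects in the loop body, with one loop per dimension of the value tensor.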