@@ -52,8 +52,7 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
     });
     addTargetMaterialization([&](OpBuilder &builder,
                                  UnrankedMemRefType resultType,
-                                 ValueRange inputs,
-                                 Location loc) -> Value {
+                                 ValueRange inputs, Location loc) -> Value {
      return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
          .getResult(0);
    });
@@ -158,7 +157,7 @@ struct ScalarStoreConverter : public OpConversionPattern<tts::ScatterOp> {
   }
 };
 
-// Lowering an unstructured load op (gather) into a linalg.generic op
+// Lowering an unstructured load op (gather) into a linalg.generic op.
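+//
+// For illustration, a masked 1-D gather lowers to IR of roughly this shape
+// (value names, shapes, and element types below are illustrative, not actual
+// compiler output):
+//
+//   %base  = memref.cast %ptr : memref<*xf32> to memref<?xf32>
+//   %src   = bufferization.to_tensor %base restrict : memref<?xf32>
+//   %empty = tensor.empty() : tensor<128xf32>
+//   %res   = linalg.generic
+//              ins(%offsets, %mask : tensor<128xi32>, tensor<128xi1>)
+//              outs(%empty : tensor<128xf32>) {
+//            ^bb0(%offset: i32, %m: i1, %out: f32):
+//              // tensor.extract from %src at %offset, guarded by scf.if on %m
+//              linalg.yield %val : f32
+//            } -> tensor<128xf32>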
 struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   using OpConversionPattern<tts::GatherOp>::OpConversionPattern;
 
@@ -171,118 +170,109 @@ struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   LogicalResult
   matchAndRewrite(tts::GatherOp gatherOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
     auto loc = gatherOp->getLoc();
 
     auto ptr = adaptor.getPtr();
     auto offsetTensor = adaptor.getOffset();
     auto offsetType = dyn_cast<ShapedType>(offsetTensor.getType());
 
-    // This must be a scalar load, skip processing
+    // This must be a scalar load; skip processing.
     if (!offsetType) {
       return failure();
     }
 
-    auto loadResultType =
+    auto resultType =
         dyn_cast<RankedTensorType>(gatherOp.getResult().getType());
 
     // Treat the base pointer (memref) as 1D because the offsets are all
     // relative to a single base pointer (already collapsed).
-    auto baseMemref = rewriter.create<memref::CastOp>(
-        loc,
-        MemRefType::get({ShapedType::kDynamic},
-                        loadResultType.getElementType()),
-        ptr);
+    auto baseMemref = rewriter
+                          .create<memref::CastOp>(
+                              loc,
+                              MemRefType::get({ShapedType::kDynamic},
+                                              resultType.getElementType()),
+                              ptr)
+                          .getResult();
 
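+    // `restrict` asserts that no other buffer aliases this memref while the
+    // tensor is live; `writable` is false because the gather only reads it.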
     auto baseTensor =
         rewriter
             .create<bufferization::ToTensorOp>(
                 loc,
                 RankedTensorType::get(
                     SmallVector<int64_t>(1, ShapedType::kDynamic),
-                    loadResultType.getElementType()),
+                    resultType.getElementType()),
                 baseMemref, true /*restrict*/, false /*writable*/)
             .getResult();
 
     // The linalg.generic op should have the following inputs:
-    // - the offset tensor
-    // - an optional mask tensor if the load op contains mask
+    // - the offset tensor.
+    // - an optional mask tensor if the gather op has a mask.
     SmallVector<Value> inputs{offsetTensor};
 
     if (gatherOp.getMask()) {
       inputs.push_back(gatherOp.getMask());
     }
 
-    auto emptyTensor =
-        rewriter
-            .create<tensor::EmptyOp>(loc, loadResultType.getShape(),
-                                     loadResultType.getElementType())
-            .getResult();
+    auto emptyTensor = rewriter
+                           .create<tensor::EmptyOp>(loc, resultType.getShape(),
+                                                    resultType.getElementType())
+                           .getResult();
 
-    // Affine maps for the inputs and output
-    // If no mask is used, 2 affine maps are generated; one for the input offset
-    // tensor, the other for the output tensor.
-    // If mask is used, the first 2 maps are for the offset and mask tensors
-    // while the last map is for the output tensor.
+    // Identity affine maps for the inputs plus one for the output.
     SmallVector<AffineMap> affineMaps(
-        gatherOp.getMask() ? 3 : 2,
-        rewriter.getMultiDimIdentityMap(loadResultType.getRank()));
+        inputs.size() + 1,
+        rewriter.getMultiDimIdentityMap(resultType.getRank()));
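+    // e.g., for a rank-2 result each map is the identity (d0, d1) -> (d0, d1).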
+
+    // All iterator types are parallel.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        resultType.getRank(), utils::IteratorType::parallel);
 
     auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, SmallVector<Type>({loadResultType}), inputs,
-        ValueRange{emptyTensor}, affineMaps,
-        SmallVector<utils::IteratorType>(loadResultType.getRank(),
-                                         utils::IteratorType::parallel),
-        [&](OpBuilder &b, Location loc, ValueRange args) {
-          auto getValueAtIndex = [baseTensor](Value indexValue, Location loc,
-                                              OpBuilder &b) -> Value {
+        loc, TypeRange{resultType}, inputs, ValueRange{emptyTensor}, affineMaps,
+        iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto getValueAtIndex = [baseTensor](OpBuilder &b, Location loc,
+                                              Value index) -> Value {
            Value index0 =
-               b.create<arith::IndexCastOp>(loc, b.getIndexType(), indexValue);
+               b.create<arith::IndexCastOp>(loc, b.getIndexType(), index);
 
            return b.create<tensor::ExtractOp>(loc, baseTensor,
                                               ValueRange{index0});
          };
 
+          auto offset = args[0];
+
           if (!gatherOp.getMask()) {
             // If there is no mask, simply extract the current element from the
             // base tensor and use it as the yield value.
-            auto loadValue = getValueAtIndex(args[0], loc, rewriter);
-            rewriter.create<linalg::YieldOp>(loc, loadValue);
+            auto loadValue = getValueAtIndex(b, loc, offset);
+            b.create<linalg::YieldOp>(loc, loadValue);
           } else {
             // If the mask value is truthy, the current element is loaded from
-            // the base tensor using its offset. Otherwise, if `other` is
+            // the base memref using its offset. Otherwise, if `other` is
             // present, yield `other`. If `other` is not present, a default
             // value of 0 is used.
             auto mask = args[1];
-            auto ifOp = rewriter.create<scf::IfOp>(
+            auto ifOp = b.create<scf::IfOp>(
                 loc, mask,
                 [&](OpBuilder &b, Location loc) {
-                  // Truthy case, load from the index
-                  auto loadValue = getValueAtIndex(args[0], loc, b);
-                  b.create<scf::YieldOp>(loc, loadValue);
+                  // Truthy case: load from the index.
+                  auto value = getValueAtIndex(b, loc, offset);
+                  b.create<scf::YieldOp>(loc, value);
                 },
                 [&](OpBuilder &b, Location loc) {
-                  // Falsy case, yield `other` or 0 as the default value
+                  // Falsy case: yield `other` or 0 as the default value.
                   if (gatherOp.getOther()) {
                     b.create<scf::YieldOp>(loc, gatherOp.getOther());
                   } else {
-                    auto elemType = baseTensor.getType().getElementType();
-                    Value extract;
-                    if (isa<IntegerType>(elemType)) {
-                      extract = rewriter.create<arith::ConstantOp>(
-                          loc, b.getIntegerAttr(elemType, 0));
-                    } else if (isa<FloatType>(elemType)) {
-                      extract = rewriter.create<arith::ConstantOp>(
-                          loc, b.getFloatAttr(elemType, 0));
-                    } else {
-                      elemType.dump();
-                      llvm_unreachable("unexpected type");
-                    }
+                    auto elemType = resultType.getElementType();
+                    auto zeroAttr = b.getZeroAttr(elemType);
+                    assert(zeroAttr && "unexpected element type");
+                    Value extract = b.create<arith::ConstantOp>(loc, zeroAttr);
                     b.create<scf::YieldOp>(loc, extract);
                   }
                 });
 
-            rewriter.create<linalg::YieldOp>(loc, ifOp->getResult(0));
+            b.create<linalg::YieldOp>(loc, ifOp->getResult(0));
           }
         });
 
@@ -292,7 +282,7 @@ struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   }
 };
 
-// Lowering an unstructured store op (scatter) into an affine loop nest
+// Lowering an unstructured store op (scatter) into a linalg.generic op.
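+//
+// For illustration, a masked 1-D scatter lowers to IR of roughly this shape
+// (value names, shapes, and element types below are illustrative, not actual
+// compiler output):
+//
+//   %base = memref.cast %ptr : memref<*xf32> to memref<?xf32>
+//   linalg.generic
+//       ins(%offsets, %values, %mask
+//           : tensor<128xi32>, tensor<128xf32>, tensor<128xi1>) {
+//     ^bb0(%offset: i32, %value: f32, %m: i1):
+//       // memref.store %value into %base at %offset, guarded by scf.if on %m
+//       linalg.yield
+//   }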
 struct ScatterConverter : public OpConversionPattern<tts::ScatterOp> {
   using OpConversionPattern<tts::ScatterOp>::OpConversionPattern;
 
@@ -309,57 +299,82 @@ struct ScatterConverter : public OpConversionPattern<tts::ScatterOp> {
 
     auto ptr = adaptor.getPtr();
     auto offsetTensor = adaptor.getOffset();
+    auto valueTensor = adaptor.getValue();
     auto offsetType = dyn_cast<ShapedType>(offsetTensor.getType());
 
-    // This must be a scalar store, skip processing
+    // This must be a scalar store; skip processing.
     if (!offsetType) {
       return failure();
     }
 
-    auto resultType =
-        dyn_cast<RankedTensorType>(scatterOp.getValue().getType());
+    auto valueType = dyn_cast<RankedTensorType>(scatterOp.getValue().getType());
 
-    auto storeMemref = rewriter.create<memref::CastOp>(
-        loc,
-        MemRefType::get({ShapedType::kDynamic}, resultType.getElementType()),
-        ptr);
-
-    auto ip = rewriter.saveInsertionPoint();
-
-    SmallVector<Value> ivs;
-    for (auto dim : resultType.getShape()) {
-      auto ub =
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
-      auto forOp = rewriter.create<affine::AffineForOp>(loc, 0, dim);
-      ivs.push_back(forOp.getInductionVar());
-      rewriter.setInsertionPointToStart(forOp.getBody());
-    }
+    // Treat the base pointer (memref) as 1D because the offsets are all
+    // relative to a single base pointer (already collapsed).
+    auto baseMemref =
+        rewriter
+            .create<memref::CastOp>(loc,
+                                    MemRefType::get({ShapedType::kDynamic},
+                                                    valueType.getElementType()),
+                                    ptr)
+            .getResult();
+
+    // The linalg.generic op should have the following inputs:
+    // - the offset tensor.
+    // - the value tensor.
+    // - an optional mask tensor if the scatter op has a mask.
+    SmallVector<Value> inputs{offsetTensor, valueTensor};
 
     if (scatterOp.getMask()) {
-      // Mask case, only store the value if the mask value at `ivs` is truthy
-      auto maskValue =
-          rewriter.create<tensor::ExtractOp>(loc, scatterOp.getMask(), ivs);
+      inputs.push_back(scatterOp.getMask());
+    }
 
-      auto ifOp = rewriter.create<scf::IfOp>(loc, maskValue,
-                                             false /*withElseRegion*/);
+    // Affine maps for the inputs.
+    SmallVector<AffineMap> affineMaps(
+        inputs.size(), rewriter.getMultiDimIdentityMap(valueType.getRank()));
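+    // Unlike the gather lowering there is no output tensor and no output map;
+    // the op is used purely for its side effect of storing into the memref.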
 
-      rewriter.setInsertionPointToStart(
-          &ifOp.getThenRegion().getBlocks().front());
-    }
+    // All iterator types are parallel.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        valueType.getRank(), utils::IteratorType::parallel);
 
-    // Generate ops to store the value at each index. Note that with masking,
-    // these ops are created in the `if` block generated above.
-    auto offsetValue =
-        rewriter.create<tensor::ExtractOp>(loc, offsetTensor, ivs);
-    auto storeValue =
-        rewriter.create<tensor::ExtractOp>(loc, scatterOp.getValue(), ivs);
-    Value storeIndex = rewriter.create<arith::IndexCastOp>(
-        loc, rewriter.getIndexType(), offsetValue);
-    rewriter.create<memref::StoreOp>(loc, storeValue, storeMemref, storeIndex);
+    rewriter.setInsertionPoint(scatterOp);
+
+    rewriter.create<linalg::GenericOp>(
+        loc, TypeRange{}, inputs, ValueRange{}, affineMaps, iteratorTypes,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto storeValueAtIndex = [baseMemref](OpBuilder &b, Location loc,
+                                                Value index, Value value) {
+            Value index0 =
+                b.create<arith::IndexCastOp>(loc, b.getIndexType(), index);
+
+            b.create<memref::StoreOp>(loc, value, baseMemref,
+                                      ValueRange{index0});
+          };
+
+          auto offset = args[0];
+          auto value = args[1];
+
+          if (!scatterOp.getMask()) {
+            // If there is no mask, simply store the current value into the
+            // base memref using its offset.
+            storeValueAtIndex(b, loc, offset, value);
+          } else {
+            // If the mask value is truthy, store the current value into the
+            // base memref using its offset. Otherwise, do nothing.
+            auto mask = args[2];
+            b.create<scf::IfOp>(loc, mask, [&](OpBuilder &b, Location loc) {
+              // Truthy case: store the value at the index.
+              storeValueAtIndex(b, loc, offset, value);
+              b.create<scf::YieldOp>(loc);
+            });
+          }
+
+          b.create<linalg::YieldOp>(loc);
+        });
 
-    // Finalize
     rewriter.eraseOp(scatterOp);
-    rewriter.restoreInsertionPoint(ip);
+
     return success();
   }
 };