@@ -2160,25 +2160,25 @@ def Vector_GatherOp :
21602160 ];
21612161}
21622162
2163- def Vector_ScatterOp :
2164- Vector_Op<"scatter", [
2165- DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2166- DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2167- ]> ,
2168- Arguments<(ins Arg<AnyMemRef, "", [MemWrite] >:$base ,
2169- Variadic< Index>:$offsets ,
2170- VectorOfNonZeroRankOf<[AnyInteger, Index ]>:$indices ,
2171- VectorOfNonZeroRankOf<[I1]>:$mask ,
2172- AnyVectorOfNonZeroRank:$valueToStore ,
2173- OptionalAttr<IntValidAlignment<I64Attr>>: $alignment )> {
2163+ def Vector_ScatterOp
2164+ : Vector_Op<"scatter",
2165+ [ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2166+ DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]>,
2167+ Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemWrite]>:$base ,
2168+ Variadic<Index >:$offsets ,
2169+ VectorOfNonZeroRankOf<[AnyInteger, Index] >:$indices ,
2170+ VectorOfNonZeroRankOf<[I1 ]>:$mask ,
2171+ AnyVectorOfNonZeroRank:$valueToStore ,
2172+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)> ,
2173+ Results<(outs Optional<AnyRankedTensor>:$result )> {
21742174
21752175 let summary = [{
2176- scatters elements from a vector into memory as defined by an index vector
2176+ scatters elements from a vector into memory or ranked tensor as defined by an index vector
21772177 and a mask vector
21782178 }];
21792179
21802180 let description = [{
2181- The scatter operation stores elements from a n-D vector into memory as
2181+ The scatter operation stores elements from a n-D vector into memory or ranked tensor as
21822182 defined by a base with indices and an additional n-D index vector, but
21832183 only if the corresponding bit in a n-D mask vector is set. Otherwise, no
21842184 action is taken for that element. Informally the semantics are:
@@ -2221,31 +2221,28 @@ def Vector_ScatterOp :
22212221 }];
22222222
22232223 let extraClassDeclaration = [{
2224- MemRefType getMemRefType () { return getBase().getType(); }
2224+ ShapedType getBaseType () { return getBase().getType(); }
22252225 VectorType getIndexVectorType() { return getIndices().getType(); }
22262226 VectorType getMaskVectorType() { return getMask().getType(); }
22272227 VectorType getVectorType() { return getValueToStore().getType(); }
22282228 }];
22292229
2230- let assemblyFormat =
2231- "$base `[ ` $offsets `]` `[` $indices `]` `,` "
2232- "$mask `,` $valueToStore attr-dict `: ` type($base ) `,` "
2233- "type($indices) `,` type($mask) `, ` type($valueToStore) ";
2230+ let assemblyFormat = "$base `[` $offsets `]` `[` $indices `]` `,` "
2231+ "$mask `, ` $valueToStore attr-dict `:` type($base) `,` "
2232+ "type($indices) `, ` type($mask ) `,` "
2233+ " type($valueToStore) (`-> ` type($result)^)? ";
22342234 let hasCanonicalizer = 1;
22352235 let hasVerifier = 1;
22362236
2237- let builders = [
2238- OpBuilder<(ins "Value":$base,
2239- "ValueRange":$indices,
2240- "Value":$index_vec,
2241- "Value":$mask,
2242- "Value":$valueToStore,
2243- CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
2244- return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
2237+ let builders = [OpBuilder<
2238+ (ins "Type":$resultType, "Value":$base, "ValueRange":$indices,
2239+ "Value":$index_vec, "Value":$mask, "Value":$valueToStore,
2240+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment),
2241+ [{
2242+ return build($_builder, $_state, resultType, base, indices, index_vec, mask, valueToStore,
22452243 alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
22462244 nullptr);
2247- }]>
2248- ];
2245+ }]>];
22492246}
22502247
22512248def Vector_ExpandLoadOp :
0 commit comments