@@ -44,7 +44,7 @@ class TPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
4444}
4545
4646// TODO(b/369418606): Find out the way to verify vreg size.
47- def TPU_Vreg : Type<IsVectorTypePred , "native-sized vreg", "::mlir::VectorType">;
47+ def TPU_Vreg : Type<IsVectorOfNonZeroRankTypePred , "native-sized vreg", "::mlir::VectorType">;
4848
4949class TPU_Type<string name, string mnemonic_, list<Trait> traits = []>
5050 : TypeDef<TPU_Dialect, name, traits> {
@@ -179,8 +179,8 @@ def TPU_ReductionKindAttr
179179}
180180
181181def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure, SameOperandsAndResultType]> {
182- let arguments = (ins AnyVector :$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind);
183- let results = (outs AnyVector :$output);
182+ let arguments = (ins AnyVectorOfNonZeroRank :$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind);
183+ let results = (outs AnyVectorOfNonZeroRank :$output);
184184 let assemblyFormat = [{
185185 $input attr-dict `:` type($input)
186186 }];
@@ -217,11 +217,11 @@ def TPU_LoadOp : TPU_Op<"load"> {
217217// TODO(jevinjiang): migrate tpu.strided_store to general vector store op.
218218def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> {
219219 let arguments = (ins
220- AnyVector :$valueToStore,
220+ AnyVectorOfNonZeroRank :$valueToStore,
221221 AnyMemRef:$base,
222222 Variadic<Index>:$indices,
223223 DenseI32ArrayAttr:$strides,
224- Optional<AnyVector >:$mask // Elementwise mask.
224+ Optional<AnyVectorOfNonZeroRank >:$mask // Elementwise mask.
225225 );
226226 let results = (outs);
227227 let assemblyFormat = [{
@@ -236,7 +236,7 @@ def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
236236 Variadic<Index>:$indices,
237237 DenseI32ArrayAttr:$strides
238238 );
239- let results = (outs AnyVector :$result);
239+ let results = (outs AnyVectorOfNonZeroRank :$result);
240240 let assemblyFormat = [{
241241 $base `[` $indices `]` attr-dict `:` type($base) `,` type($result)
242242 }];
@@ -245,7 +245,7 @@ def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
245245
246246def TPU_StridedStoreOp : TPU_Op<"strided_store"> {
247247 let arguments = (ins
248- AnyVector :$valueToStore,
248+ AnyVectorOfNonZeroRank :$valueToStore,
249249 AnyMemRef:$base,
250250 Variadic<Index>:$indices,
251251 DenseI32ArrayAttr:$strides
@@ -291,15 +291,15 @@ def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> {
291291// TODO(jevinjiang): deprecate to use dynamic_rotate.
292292def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
293293 let arguments = (ins
294- AnyVector :$value,
294+ AnyVectorOfNonZeroRank :$value,
295295 SI32Attr:$amount,
296296 SI32Attr:$dimension,
297297 // When the stride is specified, the rotation amount for each index on the
298298 // stride dimension will be (amount + stride * index).
299299 OptionalAttr<SI32Attr>:$stride,
300300 OptionalAttr<SI32Attr>:$stride_dimension
301301 );
302- let results = (outs AnyVector :$result);
302+ let results = (outs AnyVectorOfNonZeroRank :$result);
303303 let assemblyFormat = [{
304304 $value `by` $amount `dim` $dimension (`stride` $stride `stride_dim` $stride_dimension^)? attr-dict `:` type($value)
305305 }];
@@ -308,15 +308,15 @@ def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
308308
309309def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> {
310310 let arguments = (ins
311- AnyVector :$value,
311+ AnyVectorOfNonZeroRank :$value,
312312 I32:$amount,
313313 SI32Attr:$dimension,
314314 // When the stride is specified, the rotation amount for each index on the
315315 // stride dimension will be (amount + stride * index).
316316 OptionalAttr<SI32Attr>:$stride,
317317 OptionalAttr<SI32Attr>:$stride_dimension
318318 );
319- let results = (outs AnyVector :$result);
319+ let results = (outs AnyVectorOfNonZeroRank :$result);
320320 let assemblyFormat = [{
321321 $value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result)
322322 }];
@@ -325,30 +325,30 @@ def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> {
325325
326326def TPU_IotaOp : TPU_Op<"iota", [Pure]> {
327327 let arguments = (ins OptionalAttr<I32Attr>:$dimension);
328- let results = (outs AnyVector :$output);
328+ let results = (outs AnyVectorOfNonZeroRank :$output);
329329 let assemblyFormat = [{ attr-dict `:` type($output) }];
330330}
331331
332332// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically.
333333// b/376295711
334334def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> {
335335 let arguments = (ins
336- AnyVector :$source,
336+ AnyVectorOfNonZeroRank :$source,
337337 I32Attr:$dimension,
338338 I32Attr:$times
339339 );
340- let results = (outs AnyVector :$output);
340+ let results = (outs AnyVectorOfNonZeroRank :$output);
341341 let assemblyFormat = [{ $source `,` $dimension `x` $times attr-dict `:` type($source) `->` type($output) }];
342342}
343343
344344def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
345345 let arguments = (ins
346- AnyVector :$source, // All sublanes should be equal.
346+ AnyVectorOfNonZeroRank :$source, // All sublanes should be equal.
347347 I32Attr:$lane // Coordinates of the first element to take.
348348 );
349349 // Output shape should be the same, except for position dim which contains
350350 // the newly inserted dimension.
351- let results = (outs AnyVector :$output);
351+ let results = (outs AnyVectorOfNonZeroRank :$output);
352352 let assemblyFormat = [{
353353 $source `,` $lane attr-dict `:` type($source) `->` type($output)
354354 }];
@@ -357,30 +357,30 @@ def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
357357// Integer unpacks are always signed at the moment.
358358def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
359359 let arguments = (ins
360- AnyVector :$source,
360+ AnyVectorOfNonZeroRank :$source,
361361 I32Attr:$index
362362 );
363- let results = (outs AnyVector :$output);
363+ let results = (outs AnyVectorOfNonZeroRank :$output);
364364 let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }];
365365}
366366
367367// Integer packs are always signed at the moment.
368368def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure]> {
369369 let arguments = (ins
370- Variadic<AnyVector >:$sources,
370+ Variadic<AnyVectorOfNonZeroRank >:$sources,
371371 TPU_PackFormatEnum:$pack_format
372372 );
373- let results = (outs AnyVector :$output);
373+ let results = (outs AnyVectorOfNonZeroRank :$output);
374374 let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }];
375375}
376376
377377def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
378378 let arguments = (ins
379- AnyVector :$source,
379+ AnyVectorOfNonZeroRank :$source,
380380 DenseI32ArrayAttr:$indices,
381381 I32Attr:$dimension
382382 );
383- let results = (outs AnyVector :$output);
383+ let results = (outs AnyVectorOfNonZeroRank :$output);
384384 let assemblyFormat = [{
385385 $source `[` $indices `]` `in` $dimension attr-dict
386386 `:` type($source) `->` type($output)
@@ -389,11 +389,11 @@ def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
389389
390390def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> {
391391 let arguments = (ins
392- AnyVector :$source,
393- AnyVector :$indices, // If this is 2D, only the first row matters.
392+ AnyVectorOfNonZeroRank :$source,
393+ AnyVectorOfNonZeroRank :$indices, // If this is 2D, only the first row matters.
394394 I32Attr:$dimension
395395 );
396- let results = (outs AnyVector :$output);
396+ let results = (outs AnyVectorOfNonZeroRank :$output);
397397 let assemblyFormat = [{
398398 $source `[` $indices `]` `in` $dimension attr-dict
399399 `:` type($source) `,` type($indices) `->` type($output)
@@ -424,9 +424,9 @@ def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension
424424// TODO(apaszke): Think hard about precision
425425def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> {
426426 let arguments = (ins
427- AnyVector :$lhs,
428- AnyVector :$rhs,
429- AnyVector :$acc,
427+ AnyVectorOfNonZeroRank :$lhs,
428+ AnyVectorOfNonZeroRank :$rhs,
429+ AnyVectorOfNonZeroRank :$acc,
430430 // These flags are deprecated - if dimension_numbers are defined,
431431 // these flags are ignored. They will always be false after canonicalize.
432432 DefaultValuedAttr<BoolAttr, "false">:$transpose_lhs,
@@ -435,7 +435,7 @@ def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> {
435435 // NOTE: User-level optional, once canonicalized, always present.
436436 OptionalAttr<TPU_DotDimensionNumbersAttr>:$dimension_numbers
437437 );
438- let results = (outs AnyVector :$result);
438+ let results = (outs AnyVectorOfNonZeroRank :$result);
439439 let assemblyFormat = [{
440440 $lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result)
441441 }];
@@ -445,19 +445,19 @@ def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> {
445445
446446def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> {
447447 let arguments = (ins
448- Variadic<AnyVector >:$sources,
448+ Variadic<AnyVectorOfNonZeroRank >:$sources,
449449 I32Attr:$dimension
450450 );
451- let results = (outs AnyVector :$output);
451+ let results = (outs AnyVectorOfNonZeroRank :$output);
452452 let assemblyFormat = [{
453453 $sources `in` $dimension attr-dict `:` type($sources) `->` type($output)
454454 }];
455455 let hasVerifier = 1;
456456}
457457
458458def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> {
459- let arguments = (ins AnyVector :$input);
460- let results = (outs AnyVector :$output);
459+ let arguments = (ins AnyVectorOfNonZeroRank :$input);
460+ let results = (outs AnyVectorOfNonZeroRank :$output);
461461 let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
462462 let hasVerifier = 1;
463463}
@@ -469,16 +469,16 @@ def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> {
469469}
470470
471471def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> {
472- let arguments = (ins Variadic<AnyVector >:$input);
473- let results = (outs AnyVector :$output);
472+ let arguments = (ins Variadic<AnyVectorOfNonZeroRank >:$input);
473+ let results = (outs AnyVectorOfNonZeroRank :$output);
474474 let assemblyFormat = [{
475475 $input attr-dict `:` type($input) `->` type($output)
476476 }];
477477}
478478
479479def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> {
480- let arguments = (ins AnyVector :$input);
481- let results = (outs Variadic<AnyVector >:$output);
480+ let arguments = (ins AnyVectorOfNonZeroRank :$input);
481+ let results = (outs Variadic<AnyVectorOfNonZeroRank >:$output);
482482 let hasCanonicalizeMethod = 1;
483483 let assemblyFormat = [{
484484 $input attr-dict `:` type($input) `->` type($output)
@@ -722,8 +722,8 @@ def TPU_DelayOp : TPU_Op<"delay"> {
722722
723723// Expands the granularity of mask to subelements.
724724def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> {
725- let arguments = (ins AnyVector :$input);
726- let results = (outs AnyVector :$result);
725+ let arguments = (ins AnyVectorOfNonZeroRank :$input);
726+ let results = (outs AnyVectorOfNonZeroRank :$result);
727727 let assemblyFormat = [{
728728 $input attr-dict `:` type($input) `->` type($result)
729729 }];
@@ -749,7 +749,7 @@ def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> {
749749
750750def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> {
751751 let arguments = (ins);
752- let results = (outs AnyVector :$output);
752+ let results = (outs AnyVectorOfNonZeroRank :$output);
753753}
754754
755755def TPU_LogOp : TPU_Op<"log"> {
0 commit comments