Skip to content

Commit 03b6945

Browse files
d0kGoogle-ML-Automation
authored andcommitted
Updates LLVM usage to match [b214ca82daee](llvm/llvm-project@b214ca82daee) PiperOrigin-RevId: 700689999
1 parent 7f14de0 commit 03b6945

File tree

1 file changed

+40
-40
lines changed
  • jaxlib/mosaic/dialect/tpu

1 file changed

+40
-40
lines changed

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class TPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
4444
}
4545

4646
// TODO(b/369418606): Find out the way to verify vreg size.
47-
def TPU_Vreg : Type<IsVectorTypePred, "native-sized vreg", "::mlir::VectorType">;
47+
def TPU_Vreg : Type<IsVectorOfNonZeroRankTypePred, "native-sized vreg", "::mlir::VectorType">;
4848

4949
class TPU_Type<string name, string mnemonic_, list<Trait> traits = []>
5050
: TypeDef<TPU_Dialect, name, traits> {
@@ -179,8 +179,8 @@ def TPU_ReductionKindAttr
179179
}
180180

181181
def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure, SameOperandsAndResultType]> {
182-
let arguments = (ins AnyVector:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind);
183-
let results = (outs AnyVector:$output);
182+
let arguments = (ins AnyVectorOfNonZeroRank:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind);
183+
let results = (outs AnyVectorOfNonZeroRank:$output);
184184
let assemblyFormat = [{
185185
$input attr-dict `:` type($input)
186186
}];
@@ -217,11 +217,11 @@ def TPU_LoadOp : TPU_Op<"load"> {
217217
// TODO(jevinjiang): migrate tpu.strided_store to general vector store op.
218218
def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> {
219219
let arguments = (ins
220-
AnyVector:$valueToStore,
220+
AnyVectorOfNonZeroRank:$valueToStore,
221221
AnyMemRef:$base,
222222
Variadic<Index>:$indices,
223223
DenseI32ArrayAttr:$strides,
224-
Optional<AnyVector>:$mask // Elementwise mask.
224+
Optional<AnyVectorOfNonZeroRank>:$mask // Elementwise mask.
225225
);
226226
let results = (outs);
227227
let assemblyFormat = [{
@@ -236,7 +236,7 @@ def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
236236
Variadic<Index>:$indices,
237237
DenseI32ArrayAttr:$strides
238238
);
239-
let results = (outs AnyVector:$result);
239+
let results = (outs AnyVectorOfNonZeroRank:$result);
240240
let assemblyFormat = [{
241241
$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)
242242
}];
@@ -245,7 +245,7 @@ def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
245245

246246
def TPU_StridedStoreOp : TPU_Op<"strided_store"> {
247247
let arguments = (ins
248-
AnyVector:$valueToStore,
248+
AnyVectorOfNonZeroRank:$valueToStore,
249249
AnyMemRef:$base,
250250
Variadic<Index>:$indices,
251251
DenseI32ArrayAttr:$strides
@@ -291,15 +291,15 @@ def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> {
291291
// TODO(jevinjiang): deprecate to use dynamic_rotate.
292292
def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
293293
let arguments = (ins
294-
AnyVector:$value,
294+
AnyVectorOfNonZeroRank:$value,
295295
SI32Attr:$amount,
296296
SI32Attr:$dimension,
297297
// When the stride is specified, the rotation amount for each index on the
298298
// stride dimension will be (amount + stride * index).
299299
OptionalAttr<SI32Attr>:$stride,
300300
OptionalAttr<SI32Attr>:$stride_dimension
301301
);
302-
let results = (outs AnyVector:$result);
302+
let results = (outs AnyVectorOfNonZeroRank:$result);
303303
let assemblyFormat = [{
304304
$value `by` $amount `dim` $dimension (`stride` $stride `stride_dim` $stride_dimension^)? attr-dict `:` type($value)
305305
}];
@@ -308,15 +308,15 @@ def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
308308

309309
def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> {
310310
let arguments = (ins
311-
AnyVector:$value,
311+
AnyVectorOfNonZeroRank:$value,
312312
I32:$amount,
313313
SI32Attr:$dimension,
314314
// When the stride is specified, the rotation amount for each index on the
315315
// stride dimension will be (amount + stride * index).
316316
OptionalAttr<SI32Attr>:$stride,
317317
OptionalAttr<SI32Attr>:$stride_dimension
318318
);
319-
let results = (outs AnyVector:$result);
319+
let results = (outs AnyVectorOfNonZeroRank:$result);
320320
let assemblyFormat = [{
321321
$value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result)
322322
}];
@@ -325,30 +325,30 @@ def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> {
325325

326326
def TPU_IotaOp : TPU_Op<"iota", [Pure]> {
327327
let arguments = (ins OptionalAttr<I32Attr>:$dimension);
328-
let results = (outs AnyVector:$output);
328+
let results = (outs AnyVectorOfNonZeroRank:$output);
329329
let assemblyFormat = [{ attr-dict `:` type($output) }];
330330
}
331331

332332
// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically.
333333
// b/376295711
334334
def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> {
335335
let arguments = (ins
336-
AnyVector:$source,
336+
AnyVectorOfNonZeroRank:$source,
337337
I32Attr:$dimension,
338338
I32Attr:$times
339339
);
340-
let results = (outs AnyVector:$output);
340+
let results = (outs AnyVectorOfNonZeroRank:$output);
341341
let assemblyFormat = [{ $source `,` $dimension `x` $times attr-dict `:` type($source) `->` type($output) }];
342342
}
343343

344344
def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
345345
let arguments = (ins
346-
AnyVector:$source, // All sublanes should be equal.
346+
AnyVectorOfNonZeroRank:$source, // All sublanes should be equal.
347347
I32Attr:$lane // Coordinates of the first element to take.
348348
);
349349
// Output shape should be the same, except for position dim which contains
350350
// the newly inserted dimension.
351-
let results = (outs AnyVector:$output);
351+
let results = (outs AnyVectorOfNonZeroRank:$output);
352352
let assemblyFormat = [{
353353
$source `,` $lane attr-dict `:` type($source) `->` type($output)
354354
}];
@@ -357,30 +357,30 @@ def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
357357
// Integer unpacks are always signed at the moment.
358358
def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
359359
let arguments = (ins
360-
AnyVector:$source,
360+
AnyVectorOfNonZeroRank:$source,
361361
I32Attr:$index
362362
);
363-
let results = (outs AnyVector:$output);
363+
let results = (outs AnyVectorOfNonZeroRank:$output);
364364
let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }];
365365
}
366366

367367
// Integer packs are always signed at the moment.
368368
def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure]> {
369369
let arguments = (ins
370-
Variadic<AnyVector>:$sources,
370+
Variadic<AnyVectorOfNonZeroRank>:$sources,
371371
TPU_PackFormatEnum:$pack_format
372372
);
373-
let results = (outs AnyVector:$output);
373+
let results = (outs AnyVectorOfNonZeroRank:$output);
374374
let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }];
375375
}
376376

377377
def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
378378
let arguments = (ins
379-
AnyVector:$source,
379+
AnyVectorOfNonZeroRank:$source,
380380
DenseI32ArrayAttr:$indices,
381381
I32Attr:$dimension
382382
);
383-
let results = (outs AnyVector:$output);
383+
let results = (outs AnyVectorOfNonZeroRank:$output);
384384
let assemblyFormat = [{
385385
$source `[` $indices `]` `in` $dimension attr-dict
386386
`:` type($source) `->` type($output)
@@ -389,11 +389,11 @@ def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
389389

390390
def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> {
391391
let arguments = (ins
392-
AnyVector:$source,
393-
AnyVector:$indices, // If this is 2D, only the first row matters.
392+
AnyVectorOfNonZeroRank:$source,
393+
AnyVectorOfNonZeroRank:$indices, // If this is 2D, only the first row matters.
394394
I32Attr:$dimension
395395
);
396-
let results = (outs AnyVector:$output);
396+
let results = (outs AnyVectorOfNonZeroRank:$output);
397397
let assemblyFormat = [{
398398
$source `[` $indices `]` `in` $dimension attr-dict
399399
`:` type($source) `,` type($indices) `->` type($output)
@@ -424,9 +424,9 @@ def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension
424424
// TODO(apaszke): Think hard about precision
425425
def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> {
426426
let arguments = (ins
427-
AnyVector:$lhs,
428-
AnyVector:$rhs,
429-
AnyVector:$acc,
427+
AnyVectorOfNonZeroRank:$lhs,
428+
AnyVectorOfNonZeroRank:$rhs,
429+
AnyVectorOfNonZeroRank:$acc,
430430
// These flags are deprecated - if dimension_numbers are defined,
431431
// these flags are ignored. They will always be false after canonicalize.
432432
DefaultValuedAttr<BoolAttr, "false">:$transpose_lhs,
@@ -435,7 +435,7 @@ def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> {
435435
// NOTE: User-level optional, once canonicalized, always present.
436436
OptionalAttr<TPU_DotDimensionNumbersAttr>:$dimension_numbers
437437
);
438-
let results = (outs AnyVector:$result);
438+
let results = (outs AnyVectorOfNonZeroRank:$result);
439439
let assemblyFormat = [{
440440
$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result)
441441
}];
@@ -445,19 +445,19 @@ def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> {
445445

446446
def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> {
447447
let arguments = (ins
448-
Variadic<AnyVector>:$sources,
448+
Variadic<AnyVectorOfNonZeroRank>:$sources,
449449
I32Attr:$dimension
450450
);
451-
let results = (outs AnyVector:$output);
451+
let results = (outs AnyVectorOfNonZeroRank:$output);
452452
let assemblyFormat = [{
453453
$sources `in` $dimension attr-dict `:` type($sources) `->` type($output)
454454
}];
455455
let hasVerifier = 1;
456456
}
457457

458458
def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> {
459-
let arguments = (ins AnyVector:$input);
460-
let results = (outs AnyVector:$output);
459+
let arguments = (ins AnyVectorOfNonZeroRank:$input);
460+
let results = (outs AnyVectorOfNonZeroRank:$output);
461461
let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
462462
let hasVerifier = 1;
463463
}
@@ -469,16 +469,16 @@ def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> {
469469
}
470470

471471
def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> {
472-
let arguments = (ins Variadic<AnyVector>:$input);
473-
let results = (outs AnyVector:$output);
472+
let arguments = (ins Variadic<AnyVectorOfNonZeroRank>:$input);
473+
let results = (outs AnyVectorOfNonZeroRank:$output);
474474
let assemblyFormat = [{
475475
$input attr-dict `:` type($input) `->` type($output)
476476
}];
477477
}
478478

479479
def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> {
480-
let arguments = (ins AnyVector:$input);
481-
let results = (outs Variadic<AnyVector>:$output);
480+
let arguments = (ins AnyVectorOfNonZeroRank:$input);
481+
let results = (outs Variadic<AnyVectorOfNonZeroRank>:$output);
482482
let hasCanonicalizeMethod = 1;
483483
let assemblyFormat = [{
484484
$input attr-dict `:` type($input) `->` type($output)
@@ -722,8 +722,8 @@ def TPU_DelayOp : TPU_Op<"delay"> {
722722

723723
// Expands the granularity of mask to subelements.
724724
def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> {
725-
let arguments = (ins AnyVector:$input);
726-
let results = (outs AnyVector:$result);
725+
let arguments = (ins AnyVectorOfNonZeroRank:$input);
726+
let results = (outs AnyVectorOfNonZeroRank:$result);
727727
let assemblyFormat = [{
728728
$input attr-dict `:` type($input) `->` type($result)
729729
}];
@@ -749,7 +749,7 @@ def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> {
749749

750750
def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> {
751751
let arguments = (ins);
752-
let results = (outs AnyVector:$output);
752+
let results = (outs AnyVectorOfNonZeroRank:$output);
753753
}
754754

755755
def TPU_LogOp : TPU_Op<"log"> {

0 commit comments

Comments
 (0)