Skip to content

Commit 5f8b10c

Browse files
mrguentherGoogle-ML-Automation
authored andcommitted
Add a test for the Operand Upcaster HLO pass.
PiperOrigin-RevId: 738938419
1 parent 6516fe8 commit 5f8b10c

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

xla/hlo/transforms/tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ lit_test_suite(
1313
# go/keep-sorted start
1414
"algebraic_simplifier.hlo",
1515
"cholesky_expander.hlo",
16+
"operand_upcaster.hlo",
1617
"optimization_barrier_expander.hlo",
1718
"rewrite_bf16_conv_to_onednn.hlo",
1819
"rng_bit_generator_expander.hlo",
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// NOTE: Assertions have been autogenerated by hlo/tools/generate_hlo_test_checks.py
2+
// RUN: hlo-opt %s --split-input-file --passes=operand_upcaster | FileCheck %s
3+
4+
// CHECK-LABEL: HloModule TestDot, entry_computation_layout={(s8[8]{0}, s16[8]{0})->s32[8]{0}}
5+
6+
// CHECK-LABEL: ENTRY %test_dot
7+
// CHECK-NEXT: %[[a:[^ ]+]] = s8[8]{0} parameter(0)
8+
// CHECK-NEXT: %[[convert:[^ ]+]] = s32[8]{0} convert(%[[a]])
9+
// CHECK-NEXT: %[[b:[^ ]+]] = s16[8]{0} parameter(1)
10+
// CHECK-NEXT: %[[convert_1:[^ ]+]] = s32[8]{0} convert(%[[b]])
11+
// CHECK-NEXT: ROOT %[[result:[^ ]+]] = s32[8]{0} dot(%[[convert]], %[[convert_1]]), lhs_contracting_dims={}, rhs_contracting_dims={}
12+
13+
HloModule TestDot
14+
15+
ENTRY test_dot {
16+
a = s8[8] parameter(0)
17+
b = s16[8] parameter(1)
18+
ROOT result = s32[8] dot(a, b)
19+
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: HloModule TestDotPackedNibble, entry_computation_layout={(f16[8,16]{1,0}, f16[16,8]{1,0})->f32[8,8]{1,0}}
24+
25+
// CHECK-LABEL: ENTRY %test_dot_packed_nibble
26+
// CHECK-NEXT: %[[arg_0:[^ ]+]] = f16[8,16]{1,0} parameter(0)
27+
// CHECK-NEXT: %[[convert:[^ ]+]] = f32[8,16]{1,0} convert(%[[arg_0]])
28+
// CHECK-NEXT: %[[arg_1:[^ ]+]] = f16[16,8]{1,0} parameter(1)
29+
// CHECK-NEXT: %[[convert_1:[^ ]+]] = f32[16,8]{1,0} convert(%[[arg_1]])
30+
// CHECK-NEXT: ROOT %[[dot:[^ ]+]] = f32[8,8]{1,0} dot(%[[convert]], %[[convert_1]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={packed_nibble,default}
31+
32+
HloModule TestDotPackedNibble
33+
34+
ENTRY test_dot_packed_nibble {
35+
arg_0 = f16[8,16] parameter(0)
36+
arg_1 = f16[16,8] parameter(1)
37+
ROOT dot = f32[8,8] dot(arg_0, arg_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={packed_nibble,default}
38+
}
39+
40+
// -----
41+
42+
// CHECK-LABEL: HloModule TestConvolutionPackedNibble, entry_computation_layout={(s8[3,3,7,7]{3,2,1,0}, s8[5,11,11,7]{3,2,1,0})->s32[5,11,11,7]{3,2,1,0}}
43+
44+
// CHECK-LABEL: ENTRY %test_convolution_packed_nibble
45+
// CHECK-NEXT: %[[lhs:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} parameter(1)
46+
// CHECK-NEXT: %[[constant:[^ ]+]] = s8[] constant(4)
47+
// CHECK-NEXT: %[[broadcast:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} broadcast(%[[constant]]), dimensions={}
48+
// CHECK-NEXT: %[[shift_left:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} shift-left(%[[lhs]], %[[broadcast]])
49+
// CHECK-NEXT: %[[constant_1:[^ ]+]] = s8[] constant(4)
50+
// CHECK-NEXT: %[[broadcast_1:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} broadcast(%[[constant_1]]), dimensions={}
51+
// CHECK-NEXT: %[[shift_right_arithmetic:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} shift-right-arithmetic(%[[shift_left]], %[[broadcast_1]])
52+
// CHECK-NEXT: %[[convert:[^ ]+]] = s32[5,11,11,7]{3,2,1,0} convert(%[[shift_right_arithmetic]])
53+
// CHECK-NEXT: %[[rhs:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} parameter(0)
54+
// CHECK-NEXT: %[[constant_3:[^ ]+]] = s8[] constant(4)
55+
// CHECK-NEXT: %[[broadcast_3:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} broadcast(%[[constant_3]]), dimensions={}
56+
// CHECK-NEXT: %[[shift_left_1:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} shift-left(%[[rhs]], %[[broadcast_3]])
57+
// CHECK-NEXT: %[[constant_4:[^ ]+]] = s8[] constant(4)
58+
// CHECK-NEXT: %[[broadcast_4:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} broadcast(%[[constant_4]]), dimensions={}
59+
// CHECK-NEXT: %[[shift_right_arithmetic_2:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} shift-right-arithmetic(%[[shift_left_1]], %[[broadcast_4]])
60+
// CHECK-NEXT: %[[convert_2:[^ ]+]] = s32[3,3,7,7]{3,2,1,0} convert(%[[shift_right_arithmetic_2]])
61+
// CHECK-NEXT: %[[convolution_1:[^ ]+]] = s32[5,11,11,7]{3,2,1,0} convolution(%[[convert]], %[[convert_2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
62+
// CHECK-NEXT: %[[constant_2:[^ ]+]] = s8[] constant(4)
63+
// CHECK-NEXT: %[[broadcast_2:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} broadcast(%[[constant_2]]), dimensions={}
64+
// CHECK-NEXT: %[[shift_right_arithmetic_1:[^ ]+]] = s8[5,11,11,7]{3,2,1,0} shift-right-arithmetic(%[[lhs]], %[[broadcast_2]])
65+
// CHECK-NEXT: %[[convert_1:[^ ]+]] = s32[5,11,11,7]{3,2,1,0} convert(%[[shift_right_arithmetic_1]])
66+
// CHECK-NEXT: %[[constant_5:[^ ]+]] = s8[] constant(4)
67+
// CHECK-NEXT: %[[broadcast_5:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} broadcast(%[[constant_5]]), dimensions={}
68+
// CHECK-NEXT: %[[shift_right_arithmetic_3:[^ ]+]] = s8[3,3,7,7]{3,2,1,0} shift-right-arithmetic(%[[rhs]], %[[broadcast_5]])
69+
// CHECK-NEXT: %[[convert_3:[^ ]+]] = s32[3,3,7,7]{3,2,1,0} convert(%[[shift_right_arithmetic_3]])
70+
// CHECK-NEXT: %[[convolution_2:[^ ]+]] = s32[5,11,11,7]{3,2,1,0} convolution(%[[convert_1]], %[[convert_3]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
71+
// CHECK-NEXT: ROOT %[[add:[^ ]+]] = s32[5,11,11,7]{3,2,1,0} add(%[[convolution_1]], %[[convolution_2]])
72+
73+
HloModule TestConvolutionPackedNibble
74+
75+
ENTRY test_convolution_packed_nibble {
76+
lhs = s8[5,11,11,7] parameter(1)
77+
rhs = s8[3,3,7,7] parameter(0)
78+
ROOT convolution = s32[5,11,11,7] convolution(lhs, rhs), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, operand_precision={PACKED_NIBBLE,PACKED_NIBBLE}
79+
}

0 commit comments

Comments
 (0)