Skip to content

Commit ea2e7f0

Browse files
authored
[Benchmarking] update distribution for i8 type (#1072)
Changing the `distribution` type from `255` to `127` to support random initialization for `i8` type.
1 parent b4efad4 commit ea2e7f0

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

include/TPP/Transforms/Utils/TensorInitInt.h

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,24 +94,31 @@ struct SimpleTensorInitInt : TensorInitInt {
9494

9595
// Continuous init (quantized normalized affine range).
9696
struct ContinuousTensorInitInt : TensorInitInt {
97-
ContinuousTensorInitInt(DataType type) : TensorInitInt(type) {}
97+
ContinuousTensorInitInt(DataType type)
98+
: TensorInitInt(type), upperBound(255) {
99+
if (type == DataType::I8)
100+
upperBound = 127;
101+
}
98102

99103
// Return a dense<0 ... upperBound> throughout the shape.
100104
void fillData() override;
101105

102106
// Upper bound for quantization.
103-
int upperBound = 255;
107+
int upperBound;
104108
};
105109

106110
// Random init (uniform).
107111
struct RandomTensorInitInt : TensorInitInt {
108112
RandomTensorInitInt(DataType type, int seed)
109-
: TensorInitInt(type), generator(seed), distribution(0, 255) {}
113+
: TensorInitInt(type), generator(seed), distribution(0, 255) {
114+
if (type == DataType::I8)
115+
distribution = std::uniform_int_distribution<uint64_t>(0, 127);
116+
}
110117

111118
// Next random uniform number.
112119
float next() { return distribution(generator); }
113120

114-
// Return a dense<uniform(0, 255)> throughout the shape.
121+
// Return a dense<uniform(0, distribution)> throughout the shape.
115122
void fillData() override;
116123

117124
private:
@@ -124,15 +131,18 @@ struct RandomTensorInitInt : TensorInitInt {
124131
// Random init (normal).
125132
struct NormalTensorInitInt : TensorInitInt {
126133
NormalTensorInitInt(DataType type, int seed)
127-
: TensorInitInt(type), generator(seed), distribution(255, 0.5) {}
134+
: TensorInitInt(type), generator(seed), distribution(255) {
135+
if (type == DataType::I8)
136+
distribution = std::binomial_distribution<uint64_t>(127);
137+
}
128138

129139
// Next random number.
130140
float next() {
131141
auto value = distribution(generator);
132142
return value;
133143
}
134144

135-
// Return a dense<normal(0, 255)> throughout the shape.
145+
// Return a dense<normal(0, distribution)> throughout the shape.
136146
void fillData() override;
137147

138148
private:

test/Integration/tpp-run-splat-tensor.mlir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
// RUN: tpp-run %s -e entry -entry-point-result=void -print-mlir=early -seed 123 -splat-to-random -init-type=normal 2>&1 | \
1818
// RUN: FileCheck %s --check-prefix=OPT-NORMAL
1919

20-
func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4x2xi32>, %arg3: tensor<4x2xf16>) {
20+
func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4x2xi32>, %arg3: tensor<4x2xf16>, %arg4: tensor<4x2xi8>) {
2121
%0 = arith.constant dense<1.0> : tensor<2x16xf32>
2222
%5 = arith.constant dense<1.0> : tensor<2x16xf64>
2323
%1 = arith.constant dense<2.0> : tensor<4x16xf32>
2424
%10 = arith.constant dense<2.0> : tensor<4x4xf32>
2525
%2 = arith.constant dense<0.0> : tensor<4x8xf32>
2626
%3 = arith.constant dense<[[0.0, 1.0],[2.0, 3.0]]> : tensor<2x2xf32>
27+
%13 = arith.constant dense<1> : tensor<4x2xi8>
2728
%4 = arith.constant dense<0> : tensor<4x8xi32>
2829
%6 = arith.constant dense<1> : tensor<4x8xi32>
2930
%11 = arith.constant dense<1> : tensor<4x8xi32>
@@ -39,13 +40,15 @@ func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4
3940
// SPLAT-DAG: memref.global "private" @__wrapper_1 : memref<4x2xi32> = dense<1>
4041
// SPLAT-DAG: memref.global "private" @__wrapper_2 : memref<4x2xi32> = dense<1>
4142
// SPLAT-DAG: memref.global "private" @__wrapper_3 : memref<4x2xf16> = dense<1.000000e+00>
43+
// SPLAT-DAG: memref.global "private" @__wrapper_4 : memref<4x2xi8> = dense<1>
4244
// SPLAT-LABEL: @_entry
4345
// SPLAT: arith.constant dense<1.000000e+00> : tensor<2x16xf32>
4446
// SPLAT: arith.constant dense<1.000000e+00> : tensor<2x16xf64>
4547
// SPLAT: arith.constant dense<2.000000e+00> : tensor<4x16xf32>
4648
// SPLAT: arith.constant dense<2.000000e+00> : tensor<4x4xf32>
4749
// SPLAT: arith.constant dense<0.000000e+00> : tensor<4x8xf32>
4850
// SPLAT: arith.constant dense<{{.*}}0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00{{.*}}> : tensor<2x2xf32>
51+
// SPLAT: arith.constant dense<1> : tensor<4x2xi8>
4952
// SPLAT: arith.constant dense<0> : tensor<4x8xi32>
5053
// SPLAT: arith.constant dense<1> : tensor<4x8xi32>
5154
// SPLAT: arith.constant dense<1> : tensor<4x8xi32>
@@ -62,13 +65,15 @@ func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4
6265
// RANDOM-DAG: memref.global "private" @__wrapper_1 : memref<4x2xi32> = dense<{{\[}}{{\[}}132, 126], [117, 123], [126, 121], [132, 133]]>
6366
// RANDOM-DAG: memref.global "private" @__wrapper_2 : memref<4x2xi32> = dense<{{\[}}{{\[}}129, 134], [129, 126], [141, 131], [138, 121]]>
6467
// RANDOM-DAG: memref.global "private" @__wrapper_3 : memref<4x2xf16> = dense<{{\[}}{{\[}}0.000000e+00, 1.303710e-01], [1.512450e-01, 1.063540e-02]
68+
// RANDOM-DAG: memref.global "private" @__wrapper_4 : memref<4x2xi8> = dense<{{\[}}{{\[}}67, 63], [56, 60], [62, 59], [60, 71]]>
6569
// RANDOM-LABEL: @_entry
6670
// RANDOM: arith.constant dense<1.000000e+00> : tensor<2x16xf32>
6771
// RANDOM: arith.constant dense<1.000000e+00> : tensor<2x16xf64>
6872
// RANDOM: arith.constant dense<2.000000e+00> : tensor<4x16xf32>
6973
// RANDOM: arith.constant dense<2.000000e+00> : tensor<4x4xf32>
7074
// RANDOM: arith.constant dense<0.000000e+00> : tensor<4x8xf32>
7175
// RANDOM: arith.constant dense<{{.*}}0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00{{.*}}> : tensor<2x2xf32>
76+
// RANDOM: arith.constant dense<1> : tensor<4x2xi8>
7277
// RANDOM: arith.constant dense<0> : tensor<4x8xi32>
7378
// RANDOM: arith.constant dense<1> : tensor<4x8xi32>
7479
// RANDOM: arith.constant dense<1> : tensor<4x8xi32>
@@ -85,6 +90,7 @@ func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4
8590
// RANDOM-SPLAT-NOT: memref.global "private" @__wrapper_1 : memref<4x2xi32> = dense<1>
8691
// RANDOM-SPLAT-NOT: memref.global "private" @__wrapper_2 : memref<4x2xi32> = dense<1>
8792
// RANDOM-SPLAT-NOT: memref.global "private" @__wrapper_3 : memref<4x2xf16> = dense<1.000000e+00>
93+
// RANDOM-SPLAT-NOT: memref.global "private" @__wrapper_4 : memref<4x2xi8> = dense<1>
8894
// RANDOM-SPLAT-LABEL: @_entry
8995
// RANDOM-SPLAT-NOT: arith.constant dense<1.000000e+00> : tensor<2x16xf32>
9096
// RANDOM-SPLAT-NOT: arith.constant dense<1.000000e+00> : tensor<2x16xf64>
@@ -96,6 +102,9 @@ func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4
96102
// RANDOM-SPLAT: arith.constant dense<{{\[}}{{\[}}0.0440550111, 0.221581057, 0.000000e+00{{.*}}: tensor<4x4xf32>
97103
// RANDOM-SPLAT: arith.constant dense<0.000000e+00> : tensor<4x8xf32>
98104
// RANDOM-SPLAT: arith.constant dense<{{.*}}0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00{{.*}}> : tensor<2x2xf32>
105+
// RANDOM-SPLAT-NOT: arith.constant dense<1> : tensor<4x2xi8>
106+
// RANDOM-SPLAT: arith.constant dense<{{\[}}{{\[}}67, 63{{.*}}> : tensor<4x2xi8>
107+
// RANDOM-SPLAT-NOT: arith.constant dense<1> : tensor<4x2xi8>
99108
// RANDOM-SPLAT: arith.constant dense<0> : tensor<4x8xi32>
100109
// RANDOM-SPLAT-NOT: arith.constant dense<1> : tensor<4x8xi32>
101110
// RANDOM-SPLAT-NOT: arith.constant dense<1> : tensor<4x8xi64>
@@ -116,6 +125,7 @@ func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4
116125
// OPT-CONST: arith.constant dense<1.000000e+00> : tensor<4x4xf32>
117126
// OPT-CONST: arith.constant dense<0.000000e+00> : tensor<4x8xf32>
118127
// OPT-CONST: arith.constant dense<{{.*}}0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00{{.*}}>
128+
// OPT-CONST: arith.constant dense<1> : tensor<4x2xi8>
119129
// OPT-CONST: arith.constant dense<0> : tensor<4x8xi32>
120130
// OPT-CONST: arith.constant dense<1> : tensor<4x8xi32>
121131
// OPT-CONST: arith.constant dense<1> : tensor<4x8xi32>
@@ -131,6 +141,8 @@ func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4
131141
// OPT-SIMPLE: arith.constant dense<{{.*}}3.000000e-01, 6.000000e-01, 0.899999976, {{.*}}> : tensor<4x4xf32>
132142
// OPT-SIMPLE: arith.constant dense<0.000000e+00> : tensor<4x8xf32>
133143
// OPT-SIMPLE: arith.constant dense<{{.*}}0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00{{.*}}>
144+
// OPT-SIMPLE-NOT: arith.constant dense<1> : tensor<4x2xi8>
145+
// OPT-SIMPLE: arith.constant dense<{{\[}}{{\[}}0, 1{{.*}}> : tensor<4x2xi8>
134146
// OPT-SIMPLE: arith.constant dense<0> : tensor<4x8xi32>
135147
// OPT-SIMPLE-NOT: arith.constant dense<1> : tensor<4x8xi32>
136148
// OPT-SIMPLE-NOT: arith.constant dense<1> : tensor<4x8xi64>
@@ -148,6 +160,8 @@ func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4
148160
// OPT-CONT: arith.constant dense<{{.*}}0.000000e+00, 6.250000e-02, 1.250000e-01, {{.*}}> : tensor<4x4xf32>
149161
// OPT-CONT: arith.constant dense<0.000000e+00> : tensor<4x8xf32>
150162
// OPT-CONT: arith.constant dense<{{.*}}0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00{{.*}}>
163+
// OPT-CONT-NOT: arith.constant dense<1> : tensor<4x2xi8>
164+
// OPT-CONT: arith.constant dense<{{\[}}{{\[}}0, 15{{.*}}> : tensor<4x2xi8>
151165
// OPT-CONT: arith.constant dense<0> : tensor<4x8xi32>
152166
// OPT-CONT-NOT: arith.constant dense<1> : tensor<4x8xi32>
153167
// OPT-CONT-NOT: arith.constant dense<1> : tensor<4x8xi64>
@@ -165,6 +179,8 @@ func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4
165179
// OPT-RANDOM: arith.constant dense<{{.*}}0.685934782, 0.505808651, 0.126024485, {{.*}}> : tensor<4x4xf32>
166180
// OPT-RANDOM: arith.constant dense<0.000000e+00> : tensor<4x8xf32>
167181
// OPT-RANDOM: arith.constant dense<{{.*}}0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00{{.*}}>
182+
// OPT-RANDOM: arith.constant dense<{{\[}}{{\[}}0, 22{{.*}}> : tensor<4x2xi8>
183+
// OPT-RANDOM-NOT: arith.constant dense<1> : tensor<4x2xi8>
168184
// OPT-RANDOM: arith.constant dense<0> : tensor<4x8xi32>
169185
// OPT-RANDOM-NOT: arith.constant dense<1> : tensor<4x8xi32>
170186
// OPT-RANDOM-NOT: arith.constant dense<1> : tensor<4x8xi64>
@@ -182,6 +198,8 @@ func.func @entry(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %arg2: tensor<4
182198
// OPT-NORMAL: arith.constant dense<{{.*}}0.0440550111, 0.221581057, 0.000000e+00, {{.*}}> : tensor<4x4xf32>
183199
// OPT-NORMAL: arith.constant dense<0.000000e+00> : tensor<4x8xf32>
184200
// OPT-NORMAL: arith.constant dense<{{.*}}0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00{{.*}}>
201+
// OPT-NORMAL-NOT: arith.constant dense<1> : tensor<4x2xi8>
202+
// OPT-NORMAL: arith.constant dense<{{\[}}{{\[}}67, 63{{.*}}> : tensor<4x2xi8>
185203
// OPT-NORMAL: arith.constant dense<0> : tensor<4x8xi32>
186204
// OPT-NORMAL-NOT: arith.constant dense<1> : tensor<4x8xi32>
187205
// OPT-NORMAL-NOT: arith.constant dense<1> : tensor<4x8xi64>

0 commit comments

Comments
 (0)