Skip to content

Commit f4df557

Browse files
committed
perf(mod_arith): remove remainder ops if possible
1 parent 8b12336 commit f4df557

File tree

4 files changed

+38
-32
lines changed

4 files changed

+38
-32
lines changed

benchmark/ntt/ntt_benchmark_test.cc

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -113,16 +113,19 @@ BENCHMARK(BM_intt_mont_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
113113
} // namespace
114114
} // namespace zkir
115115

116+
// clang-format off
117+
// NOLINTBEGIN(whitespace/line_length)
116118
// Run on (14 X 24 MHz CPU s)
117119
// CPU Caches:
118-
// L1 Data 64 KiB
119-
// L1 Instruction 128 KiB
120-
// L2 Unified 4096 KiB (x14)
121-
// Load Average: 9.50, 8.31, 8.95
120+
// L1 Data 64 KiB
121+
// L1 Instruction 128 KiB
122+
// L2 Unified 4096 KiB (x14)
123+
// Load Average: 27.66, 13.59, 9.67
122124
// ------------------------------------------------------------------------------
123-
// Benchmark Time CPU Iterations
125+
// Benchmark Time CPU Iterations
124126
// ------------------------------------------------------------------------------
125-
// BM_ntt_benchmark 0.339 s 0.333 s 2
126-
// BM_intt_benchmark/iterations:1 0.501 s 0.493 s 1
127-
// BM_ntt_mont_benchmark 0.379 s 0.372 s 2
128-
// BM_intt_mont_benchmark/iterations:1 0.510 s 0.504 s 1
127+
// BM_ntt_benchmark 0.190 s 0.183 s 4
128+
// BM_intt_benchmark/iterations:1 0.381 s 0.368 s 1
129+
// BM_ntt_mont_benchmark 0.221 s 0.214 s 3
130+
// BM_intt_mont_benchmark/iterations:1 0.415 s 0.396 s 1
131+
// NOLINTEND()

tests/Dialect/ModArith/mod_arith_to_arith.mlir

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,8 @@
77
// CHECK-SAME: () -> [[T:.*]] {
88
func.func @test_lower_constant() -> !mod_arith.int<3 : i5> {
99
// CHECK-NOT: mod_arith.constant
10-
// CHECK: %[[CVAL:.*]] = arith.constant 5 : [[T]]
11-
// CHECK: %[[CMOD:.*]] = arith.constant 3 : [[T]]
12-
// CHECK: %[[REMU:.*]] = arith.remui %[[CVAL]], %[[CMOD]] : [[T]]
13-
// CHECK: return %[[REMU]] : [[T]]
10+
// CHECK: %[[CVAL:.*]] = arith.constant 2 : [[T]]
11+
// CHECK: return %[[CVAL]] : [[T]]
1412
%res = mod_arith.constant 5: !mod_arith.int<3 : i5>
1513
return %res: !mod_arith.int<3 : i5>
1614
}
@@ -116,7 +114,9 @@ func.func @test_lower_add(%lhs : !Zp, %rhs : !Zp) -> !Zp {
116114
// CHECK-NOT: mod_arith.add
117115
// CHECK: %[[CMOD:.*]] = arith.constant 65537 : [[T]]
118116
// CHECK: %[[ADD:.*]] = arith.addi %[[LHS]], %[[RHS]] : [[T]]
119-
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]]
117+
// CHECK: %[[IFGE:.*]] = arith.cmpi uge, %[[ADD]], %[[CMOD]] : [[T]]
118+
// CHECK: %[[SUB:.*]] = arith.subi %[[ADD]], %[[CMOD]] : [[T]]
119+
// CHECK: %[[REM:.*]] = arith.select %[[IFGE]], %[[SUB]], %[[ADD]] : [[T]]
120120
// CHECK: return %[[REM]] : [[T]]
121121
%res = mod_arith.add %lhs, %rhs : !Zp
122122
return %res : !Zp
@@ -128,7 +128,9 @@ func.func @test_lower_add_vec(%lhs : !Zpv, %rhs : !Zpv) -> !Zpv {
128128
// CHECK-NOT: mod_arith.add
129129
// CHECK: %[[CMOD:.*]] = arith.constant dense<65537> : [[T]]
130130
// CHECK: %[[ADD:.*]] = arith.addi %[[LHS]], %[[RHS]] : [[T]]
131-
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]]
131+
// CHECK: %[[IFGE:.*]] = arith.cmpi uge, %[[ADD]], %[[CMOD]] : [[T]]
132+
// CHECK: %[[SUB:.*]] = arith.subi %[[ADD]], %[[CMOD]] : [[T]]
133+
// CHECK: %[[REM:.*]] = arith.select %[[IFGE]], %[[SUB]], %[[ADD]] : tensor<4xi1>, [[T]]
132134
// CHECK: return %[[REM]] : [[T]]
133135
%res = mod_arith.add %lhs, %rhs : !Zpv
134136
return %res : !Zpv
@@ -141,8 +143,9 @@ func.func @test_lower_sub(%lhs : !Zp, %rhs : !Zp) -> !Zp {
141143
// CHECK: %[[CMOD:.*]] = arith.constant 65537 : [[T]]
142144
// CHECK: %[[SUB:.*]] = arith.subi %[[LHS]], %[[RHS]] : [[T]]
143145
// CHECK: %[[ADD:.*]] = arith.addi %[[SUB]], %[[CMOD]] : [[T]]
144-
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]]
145-
// CHECK: return %[[REM]] : [[T]]
146+
// CHECK: %[[IFGE:.*]] = arith.cmpi uge, %[[LHS]], %[[RHS]] : [[T]]
147+
// CHECK: %[[SELECT:.*]] = arith.select %[[IFGE]], %[[SUB]], %[[ADD]] : [[T]]
148+
// CHECK: return %[[SELECT]] : [[T]]
146149
%res = mod_arith.sub %lhs, %rhs : !Zp
147150
return %res : !Zp
148151
}
@@ -154,8 +157,9 @@ func.func @test_lower_sub_vec(%lhs : !Zpv, %rhs : !Zpv) -> !Zpv {
154157
// CHECK: %[[CMOD:.*]] = arith.constant dense<65537> : [[T]]
155158
// CHECK: %[[SUB:.*]] = arith.subi %[[LHS]], %[[RHS]] : [[T]]
156159
// CHECK: %[[ADD:.*]] = arith.addi %[[SUB]], %[[CMOD]] : [[T]]
157-
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]]
158-
// CHECK: return %[[REM]] : [[T]]
160+
// CHECK: %[[IFGE:.*]] = arith.cmpi uge, %[[LHS]], %[[RHS]] : [[T]]
161+
// CHECK: %[[SELECT:.*]] = arith.select %[[IFGE]], %[[SUB]], %[[ADD]] : tensor<4xi1>, [[T]]
162+
// CHECK: return %[[SELECT]] : [[T]]
159163
%res = mod_arith.sub %lhs, %rhs : !Zpv
160164
return %res : !Zpv
161165
}
@@ -195,10 +199,8 @@ func.func @test_lower_mul_vec(%lhs : !Zpv, %rhs : !Zpv) -> !Zpv {
195199
func.func @test_lower_constant_tensor() -> !Zpv {
196200
// CHECK-NOT: mod_arith.constant
197201
// CHECK: %[[C0:.*]] = arith.constant 5 : [[INT:.*]]
198-
// CHECK: %[[C1:.*]] = arith.constant 65537 : [[INT]]
199-
// CHECK: %[[C2:.*]] = arith.remui %[[C0]], %[[C1]] : [[INT]]
200202
%c0 = mod_arith.constant 5: !Zp
201-
// CHECK: %[[RES:.*]] = tensor.from_elements %[[C2]], %[[C2]], %[[C2]], %[[C2]] : [[T]]
203+
// CHECK: %[[RES:.*]] = tensor.from_elements %[[C0]], %[[C0]], %[[C0]], %[[C0]] : [[T]]
202204
%res = tensor.from_elements %c0, %c0, %c0, %c0 : !Zpv
203205
// CHECK: return %[[RES]] : [[T]]
204206
return %res : !Zpv

zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -123,9 +123,7 @@ struct ConvertConstant : public OpConversionPattern<ConstantOp> {
123123
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
124124

125125
auto cval = b.create<arith::ConstantOp>(op.getLoc(), adaptor.getValue());
126-
auto cmod = b.create<arith::ConstantOp>(modulusAttr(op));
127-
auto remu = b.create<arith::RemUIOp>(cval, cmod);
128-
rewriter.replaceOp(op, remu);
126+
rewriter.replaceOp(op, cval);
129127
return success();
130128
}
131129
};
@@ -398,9 +396,11 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {
398396

399397
auto cmod = b.create<arith::ConstantOp>(modulusAttr(op));
400398
auto add = b.create<arith::AddIOp>(adaptor.getLhs(), adaptor.getRhs());
401-
auto remu = b.create<arith::RemUIOp>(add, cmod);
399+
auto ifge = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, add, cmod);
400+
auto sub = b.create<arith::SubIOp>(add, cmod);
401+
auto select = b.create<arith::SelectOp>(ifge, sub, add);
402402

403-
rewriter.replaceOp(op, remu);
403+
rewriter.replaceOp(op, select);
404404
return success();
405405
}
406406
};
@@ -419,9 +419,11 @@ struct ConvertSub : public OpConversionPattern<SubOp> {
419419
auto cmod = b.create<arith::ConstantOp>(modulusAttr(op));
420420
auto sub = b.create<arith::SubIOp>(adaptor.getLhs(), adaptor.getRhs());
421421
auto add = b.create<arith::AddIOp>(sub, cmod);
422-
auto remu = b.create<arith::RemUIOp>(add, cmod);
422+
auto ifge = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
423+
adaptor.getLhs(), adaptor.getRhs());
424+
auto select = b.create<arith::SelectOp>(ifge, sub, add);
423425

424-
rewriter.replaceOp(op, remu);
426+
rewriter.replaceOp(op, select);
425427
return success();
426428
}
427429
};

zkir/Dialect/ModArith/IR/ModArithDialect.cpp

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -178,11 +178,10 @@ ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
178178
}
179179

180180
// zero-extend or truncate to the correct bitwidth
181-
parsedInt = parsedInt.zextOrTrunc(outputBitWidth);
181+
parsedInt = parsedInt.zextOrTrunc(outputBitWidth).urem(modulus);
182182
result.addAttribute(
183183
"value",
184-
IntegerAttr::get(IntegerType::get(parser.getContext(), outputBitWidth),
185-
parsedInt));
184+
IntegerAttr::get(modArithType.getModulus().getType(), parsedInt));
186185
result.addTypes(parsedType);
187186
return success();
188187
}

0 commit comments

Comments
 (0)