@@ -99,6 +99,12 @@ maybeUnpackOperands(OpBuilder &builder, Location loc, ValueRange operands,
9999 return std::make_pair (targets, SmallVector<Value>{});
100100}
101101
102+ static Value emitDiscriminate (OpBuilder &builder, Location loc, Value val) {
103+ if (isa<quake::MeasureType>(val.getType ()))
104+ return builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), val);
105+ return val;
106+ }
107+
102108namespace {
103109// Type used to specialize the buildOp function. This extends the cases below by
104110// prefixing a single parameter value to the list of arguments for cases 1
@@ -637,11 +643,7 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
637643 }
638644 case clang::CastKind::CK_IntegralToFloating: {
639645 auto value = popValue ();
640- // If source is `!quake.measure`, discriminate it first
641- if (isa<quake::MeasureType>(value.getType ())) {
642- value = builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (),
643- value);
644- }
646+ value = emitDiscriminate (builder, loc, value);
645647 auto mode =
646648 (x->getSubExpr ()->getType ()->isUnsignedIntegerOrEnumerationType ())
647649 ? cudaq::cc::CastOpMode::Unsigned
@@ -651,24 +653,14 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
651653 }
652654 case clang::CastKind::CK_IntegralToBoolean: {
653655 auto last = popValue ();
654- // If the value is `!quake.measure`, discriminate it first
655- if (isa<quake::MeasureType>(last.getType ())) {
656- last =
657- builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), last);
658- return pushValue (last);
659- }
656+ last = emitDiscriminate (builder, loc, last);
660657 Value zero = builder.create <arith::ConstantIntOp>(loc, 0 , last.getType ());
661658 return pushValue (builder.create <arith::CmpIOp>(
662659 loc, arith::CmpIPredicate::ne, last, zero));
663660 }
664661 case clang::CastKind::CK_FloatingToBoolean: {
665662 auto last = popValue ();
666- // If the value is `!quake.measure`, discriminate it first
667- if (isa<quake::MeasureType>(last.getType ())) {
668- last =
669- builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), last);
670- return pushValue (last);
671- }
663+ last = emitDiscriminate (builder, loc, last);
672664 Value zero = opt::factory::createFloatConstant (
673665 loc, builder, 0.0 , cast<FloatType>(last.getType ()));
674666 return pushValue (builder.create <arith::CmpFOp>(
@@ -687,7 +679,7 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
687679 auto i1Type = builder.getI1Type ();
688680 // Handle conversion of `measure_result`
689681 if (isa<quake::MeasureType>(sub.getType ())) {
690- auto i1Val = builder. create <quake::DiscriminateOp>(loc, i1Type , sub);
682+ auto i1Val = emitDiscriminate (builder, loc , sub);
691683 // Convert to `int`
692684 if (isa<IntegerType>(castToTy))
693685 return pushValue (
@@ -860,12 +852,8 @@ bool QuakeBridgeVisitor::VisitBinaryOperator(clang::BinaryOperator *x) {
860852 rhs = maybeLoadValue (rhs);
861853 lhs = maybeLoadValue (lhs);
862854 // Discriminate measure types before comparison
863- if (isa<quake::MeasureType>(lhs.getType ()))
864- lhs =
865- builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), lhs);
866- if (isa<quake::MeasureType>(rhs.getType ()))
867- rhs =
868- builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), rhs);
855+ lhs = emitDiscriminate (builder, loc, lhs);
856+ rhs = emitDiscriminate (builder, loc, rhs);
869857 // Floating point comparison?
870858 if (isa<FloatType>(lhs.getType ())) {
871859 arith::CmpFPredicate pred;
@@ -945,10 +933,8 @@ bool QuakeBridgeVisitor::VisitBinaryOperator(clang::BinaryOperator *x) {
945933 rhs = maybeLoadValue (rhs);
946934 lhs = maybeLoadValue (lhs);
947935 // Discriminate measure types before arithmetic
948- if (isa<quake::MeasureType>(lhs.getType ()))
949- lhs = builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), lhs);
950- if (isa<quake::MeasureType>(rhs.getType ()))
951- rhs = builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), rhs);
936+ lhs = emitDiscriminate (builder, loc, lhs);
937+ rhs = emitDiscriminate (builder, loc, rhs);
952938 castToSameType (builder, loc, x->getLHS ()->getType ().getTypePtrOrNull (), lhs,
953939 x->getRHS ()->getType ().getTypePtrOrNull (), rhs);
954940 switch (x->getOpcode ()) {
@@ -1036,10 +1022,7 @@ bool QuakeBridgeVisitor::TraverseConditionalOperator(
10361022 if (!TraverseStmt (x->getCond ()))
10371023 return false ;
10381024 auto condVal = popValue ();
1039- // Discriminate if condition is `!quake.measure`
1040- if (isa<quake::MeasureType>(condVal.getType ()))
1041- condVal = builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (),
1042- condVal);
1025+ condVal = emitDiscriminate (builder, loc, condVal);
10431026 Type resultTy = builder.getI64Type ();
10441027
10451028 // Create shared lambda for the x->getTrueExpr() and x->getFalseExpr()
@@ -1622,12 +1605,8 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
16221605 if (isa<cc::PointerType>(rhs.getType ()))
16231606 rhs = builder.create <cc::LoadOp>(loc, rhs);
16241607 // Discriminate measure types
1625- if (isa<quake::MeasureType>(lhs.getType ()))
1626- lhs = builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (),
1627- lhs);
1628- if (isa<quake::MeasureType>(rhs.getType ()))
1629- rhs = builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (),
1630- rhs);
1608+ lhs = emitDiscriminate (builder, loc, lhs);
1609+ rhs = emitDiscriminate (builder, loc, rhs);
16311610 // Choose predicate based on operator
16321611 auto pred = (opKind == clang::OO_EqualEqual) ? arith::CmpIPredicate::eq
16331612 : arith::CmpIPredicate::ne;
0 commit comments