@@ -636,21 +636,39 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
636636 builder.create <cudaq::cc::CastOp>(loc, castToTy, popValue (), mode));
637637 }
638638 case clang::CastKind::CK_IntegralToFloating: {
639+ auto value = popValue ();
640+ // If source is `!quake.measure`, discriminate it first
641+ if (isa<quake::MeasureType>(value.getType ())) {
642+ value = builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (),
643+ value);
644+ }
639645 auto mode =
640646 (x->getSubExpr ()->getType ()->isUnsignedIntegerOrEnumerationType ())
641647 ? cudaq::cc::CastOpMode::Unsigned
642648 : cudaq::cc::CastOpMode::Signed;
643649 return pushValue (
644- builder.create <cudaq::cc::CastOp>(loc, castToTy, popValue () , mode));
650+ builder.create <cudaq::cc::CastOp>(loc, castToTy, value , mode));
645651 }
646652 case clang::CastKind::CK_IntegralToBoolean: {
647653 auto last = popValue ();
654+ // If the value is `!quake.measure`, discriminate it first
655+ if (isa<quake::MeasureType>(last.getType ())) {
656+ last =
657+ builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), last);
658+ return pushValue (last);
659+ }
648660 Value zero = builder.create <arith::ConstantIntOp>(loc, 0 , last.getType ());
649661 return pushValue (builder.create <arith::CmpIOp>(
650662 loc, arith::CmpIPredicate::ne, last, zero));
651663 }
652664 case clang::CastKind::CK_FloatingToBoolean: {
653665 auto last = popValue ();
666+ // If the value is `!quake.measure`, discriminate it first
667+ if (isa<quake::MeasureType>(last.getType ())) {
668+ last =
669+ builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), last);
670+ return pushValue (last);
671+ }
654672 Value zero = opt::factory::createFloatConstant (
655673 loc, builder, 0.0 , cast<FloatType>(last.getType ()));
656674 return pushValue (builder.create <arith::CmpFOp>(
@@ -667,10 +685,20 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
667685 return result;
668686 }
669687 auto i1Type = builder.getI1Type ();
670-
671- // Handle conversion of `measure_result` to `bool`.
672- if (isa<quake::MeasureType>(sub.getType ()))
673- return pushValue (builder.create <quake::DiscriminateOp>(loc, i1Type, sub));
688+ // Handle conversion of `measure_result`
689+ if (isa<quake::MeasureType>(sub.getType ())) {
690+ auto i1Val = builder.create <quake::DiscriminateOp>(loc, i1Type, sub);
691+ // Convert to `int`
692+ if (isa<IntegerType>(castToTy))
693+ return pushValue (
694+ builder.create <cudaq::cc::CastOp>(loc, castToTy, i1Val));
695+ // Convert to `float`
696+ if (isa<FloatType>(castToTy))
697+ return pushValue (builder.create <cudaq::cc::CastOp>(
698+ loc, castToTy, i1Val, cudaq::cc::CastOpMode::Unsigned));
699+ // Otherwise, just return the `i1` value
700+ return pushValue (i1Val);
701+ }
674702
675703 // Handle conversion of `std::vector<measure_result>` to `std::vector<bool>`
676704 if (auto vecTy = dyn_cast<cc::StdvecType>(sub.getType ()))
@@ -831,6 +859,13 @@ bool QuakeBridgeVisitor::VisitBinaryOperator(clang::BinaryOperator *x) {
831859 x->getOpcode () == clang::BinaryOperatorKind::BO_NE) {
832860 rhs = maybeLoadValue (rhs);
833861 lhs = maybeLoadValue (lhs);
862+ // Discriminate measure types before comparison
863+ if (isa<quake::MeasureType>(lhs.getType ()))
864+ lhs =
865+ builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), lhs);
866+ if (isa<quake::MeasureType>(rhs.getType ()))
867+ rhs =
868+ builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), rhs);
834869 // Floating point comparison?
835870 if (isa<FloatType>(lhs.getType ())) {
836871 arith::CmpFPredicate pred;
@@ -909,6 +944,11 @@ bool QuakeBridgeVisitor::VisitBinaryOperator(clang::BinaryOperator *x) {
909944 }
910945 rhs = maybeLoadValue (rhs);
911946 lhs = maybeLoadValue (lhs);
947+ // Discriminate measure types before arithmetic
948+ if (isa<quake::MeasureType>(lhs.getType ()))
949+ lhs = builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), lhs);
950+ if (isa<quake::MeasureType>(rhs.getType ()))
951+ rhs = builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (), rhs);
912952 castToSameType (builder, loc, x->getLHS ()->getType ().getTypePtrOrNull (), lhs,
913953 x->getRHS ()->getType ().getTypePtrOrNull (), rhs);
914954 switch (x->getOpcode ()) {
@@ -996,6 +1036,10 @@ bool QuakeBridgeVisitor::TraverseConditionalOperator(
9961036 if (!TraverseStmt (x->getCond ()))
9971037 return false ;
9981038 auto condVal = popValue ();
1039+ // Discriminate if condition is `!quake.measure`
1040+ if (isa<quake::MeasureType>(condVal.getType ()))
1041+ condVal = builder.create <quake::DiscriminateOp>(loc, builder.getI1Type (),
1042+ condVal);
9991043 Type resultTy = builder.getI64Type ();
10001044
10011045 // Create shared lambda for the x->getTrueExpr() and x->getFalseExpr()
@@ -2147,19 +2191,30 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
21472191 }
21482192
21492193 if (funcName == " toInteger" || funcName == " to_integer" ) {
2194+ auto arg = args[0 ];
2195+ // Insert discriminate if input is `!cc.stdvec<!quake.measure>`
2196+ if (auto vecTy = dyn_cast<cc::StdvecType>(arg.getType ())) {
2197+ if (isa<quake::MeasureType>(vecTy.getElementType ())) {
2198+ auto i1Ty = builder.getI1Type ();
2199+ arg = builder.create <quake::DiscriminateOp>(
2200+ loc, cc::StdvecType::get (i1Ty), arg);
2201+ }
2202+ }
21502203 IRBuilder irBuilder (builder.getContext ());
21512204 if (failed (irBuilder.loadIntrinsic (module , cudaqConvertToInteger))) {
21522205 reportClangError (x, mangler, " cannot load cudaqConvertToInteger" );
21532206 return false ;
21542207 }
21552208 auto i64Ty = builder.getI64Type ();
2156- return pushValue (
2157- builder.create <func::CallOp>(loc, i64Ty, cudaqConvertToInteger, args)
2158- .getResult (0 ));
2209+ return pushValue (builder
2210+ .create <func::CallOp>(loc, i64Ty,
2211+ cudaqConvertToInteger,
2212+ ValueRange{arg})
2213+ .getResult (0 ));
21592214 }
21602215
21612216 if (funcName == " to_bool_vector" ) {
2162- // args[0] is !cc.stdvec<!quake.measure> from mz()
2217+ // ` args[0]` is ` !cc.stdvec<!quake.measure>`
21632218 auto arg = args[0 ];
21642219 // Insert discriminate if needed
21652220 if (auto vecTy = dyn_cast<cc::StdvecType>(arg.getType ())) {
@@ -2169,15 +2224,7 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
21692224 loc, cc::StdvecType::get (i1Ty), arg);
21702225 }
21712226 }
2172- IRBuilder irBuilder (builder.getContext ());
2173- if (failed (irBuilder.loadIntrinsic (module , cudaqConvertToBoolVector))) {
2174- reportClangError (x, mangler, " cannot load cudaqConvertToBoolVector" );
2175- return false ;
2176- }
2177- return pushValue (builder
2178- .create <func::CallOp>(loc, arg.getType (),
2179- cudaqConvertToBoolVector, arg)
2180- .getResult (0 ));
2227+ return pushValue (arg);
21812228 }
21822229
21832230 if (funcName == " slice_vector" ) {
0 commit comments