Skip to content

Commit 119b4ef

Browse files
[naga wgsl-in] Short-circuiting of && and || operators (#7339)
Addresses parts of #4394 and #6302
1 parent 1f99103 commit 119b4ef

File tree

9 files changed

+1090
-563
lines changed

9 files changed

+1090
-563
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206
152152
- Validate that buffers are unmapped in `write_buffer` calls. By @ErichDonGubler in [#8454](https://github.com/gfx-rs/wgpu/pull/8454).
153153
- Add WGSL parsing for mesh shaders. By @inner-daemons in [#8370](https://github.com/gfx-rs/wgpu/pull/8370).
154154

155+
#### naga
156+
157+
- The `||` and `&&` operators now "short circuit", i.e., do not evaluate the RHS if the result can be determined from just the LHS. By @andyleiserson in [#7339](https://github.com/gfx-rs/wgpu/pull/7339).
158+
155159
#### DX12
156160

157161
- Align copies b/w textures and buffers via a single intermediate buffer per copy when `D3D12_FEATURE_DATA_D3D12_OPTIONS13.UnrestrictedBufferTextureCopyPitchSupported` is `false`. By @ErichDonGubler in [#7721](https://github.com/gfx-rs/wgpu/pull/7721).

naga/src/front/wgsl/lower/mod.rs

Lines changed: 250 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,13 @@ impl TypeContext for ExpressionContext<'_, '_, '_> {
426426
}
427427

428428
impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
429+
const fn is_runtime(&self) -> bool {
430+
match self.expr_type {
431+
ExpressionContextType::Runtime(_) => true,
432+
ExpressionContextType::Constant(_) | ExpressionContextType::Override => false,
433+
}
434+
}
435+
429436
#[allow(dead_code)]
430437
fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> {
431438
ExpressionContext {
@@ -588,6 +595,16 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
588595
}
589596
}
590597

598+
fn get(&self, handle: Handle<crate::Expression>) -> &crate::Expression {
599+
match self.expr_type {
600+
ExpressionContextType::Runtime(ref ctx)
601+
| ExpressionContextType::Constant(Some(ref ctx)) => &ctx.function.expressions[handle],
602+
ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
603+
&self.module.global_expressions[handle]
604+
}
605+
}
606+
}
607+
591608
fn local(
592609
&mut self,
593610
local: &Handle<ast::Local>,
@@ -614,6 +631,52 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
614631
}
615632
}
616633

634+
fn with_nested_runtime_expression_ctx<'a, F, T>(
635+
&mut self,
636+
span: Span,
637+
f: F,
638+
) -> Result<'source, (T, crate::Block)>
639+
where
640+
for<'t> F: FnOnce(&mut ExpressionContext<'source, 't, 't>) -> Result<'source, T>,
641+
{
642+
let mut block = crate::Block::new();
643+
let rctx = match self.expr_type {
644+
ExpressionContextType::Runtime(ref mut rctx) => Ok(rctx),
645+
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
646+
Err(Error::UnexpectedOperationInConstContext(span))
647+
}
648+
}?;
649+
650+
rctx.block
651+
.extend(rctx.emitter.finish(&rctx.function.expressions));
652+
rctx.emitter.start(&rctx.function.expressions);
653+
654+
let nested_rctx = LocalExpressionContext {
655+
local_table: rctx.local_table,
656+
function: rctx.function,
657+
block: &mut block,
658+
emitter: rctx.emitter,
659+
typifier: rctx.typifier,
660+
local_expression_kind_tracker: rctx.local_expression_kind_tracker,
661+
};
662+
let mut nested_ctx = ExpressionContext {
663+
expr_type: ExpressionContextType::Runtime(nested_rctx),
664+
ast_expressions: self.ast_expressions,
665+
types: self.types,
666+
globals: self.globals,
667+
module: self.module,
668+
const_typifier: self.const_typifier,
669+
layouter: self.layouter,
670+
global_expression_kind_tracker: self.global_expression_kind_tracker,
671+
};
672+
let ret = f(&mut nested_ctx)?;
673+
674+
block.extend(rctx.emitter.finish(&rctx.function.expressions));
675+
rctx.emitter.start(&rctx.function.expressions);
676+
677+
Ok((ret, block))
678+
}
679+
617680
fn gather_component(
618681
&mut self,
619682
expr: Handle<ir::Expression>,
@@ -2375,6 +2438,130 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
23752438
expr.try_map(|handle| ctx.append_expression(handle, span))
23762439
}
23772440

2441+
/// Generate IR for the short-circuiting operators `&&` and `||`.
2442+
///
2443+
/// `binary` has already lowered the LHS expression and resolved its type.
2444+
fn logical(
2445+
&mut self,
2446+
op: crate::BinaryOperator,
2447+
left: Handle<crate::Expression>,
2448+
right: Handle<ast::Expression<'source>>,
2449+
span: Span,
2450+
ctx: &mut ExpressionContext<'source, '_, '_>,
2451+
) -> Result<'source, Typed<crate::Expression>> {
2452+
debug_assert!(
2453+
op == crate::BinaryOperator::LogicalAnd || op == crate::BinaryOperator::LogicalOr
2454+
);
2455+
2456+
if ctx.is_runtime() {
2457+
// To simulate short-circuiting behavior, we want to generate IR
2458+
// like the following for `&&`. For `||`, the condition is `!_lhs`
2459+
// and the else value is `true`.
2460+
//
2461+
// var _e0: bool;
2462+
// if _lhs {
2463+
// _e0 = _rhs;
2464+
// } else {
2465+
// _e0 = false;
2466+
// }
2467+
2468+
let (condition, else_val) = if op == crate::BinaryOperator::LogicalAnd {
2469+
let condition = left;
2470+
let else_val = ctx.append_expression(
2471+
crate::Expression::Literal(crate::Literal::Bool(false)),
2472+
span,
2473+
)?;
2474+
(condition, else_val)
2475+
} else {
2476+
let condition = ctx.append_expression(
2477+
crate::Expression::Unary {
2478+
op: crate::UnaryOperator::LogicalNot,
2479+
expr: left,
2480+
},
2481+
span,
2482+
)?;
2483+
let else_val = ctx.append_expression(
2484+
crate::Expression::Literal(crate::Literal::Bool(true)),
2485+
span,
2486+
)?;
2487+
(condition, else_val)
2488+
};
2489+
2490+
let bool_ty = ctx.ensure_type_exists(crate::TypeInner::Scalar(crate::Scalar::BOOL));
2491+
2492+
let rctx = ctx.runtime_expression_ctx(span)?;
2493+
let result_var = rctx.function.local_variables.append(
2494+
crate::LocalVariable {
2495+
name: None,
2496+
ty: bool_ty,
2497+
init: None,
2498+
},
2499+
span,
2500+
);
2501+
let pointer =
2502+
ctx.append_expression(crate::Expression::LocalVariable(result_var), span)?;
2503+
2504+
let (right, mut accept) = ctx.with_nested_runtime_expression_ctx(span, |ctx| {
2505+
let right = self.expression_for_abstract(right, ctx)?;
2506+
ctx.grow_types(right)?;
2507+
Ok(right)
2508+
})?;
2509+
2510+
accept.push(
2511+
crate::Statement::Store {
2512+
pointer,
2513+
value: right,
2514+
},
2515+
span,
2516+
);
2517+
2518+
let mut reject = crate::Block::with_capacity(1);
2519+
reject.push(
2520+
crate::Statement::Store {
2521+
pointer,
2522+
value: else_val,
2523+
},
2524+
span,
2525+
);
2526+
2527+
let rctx = ctx.runtime_expression_ctx(span)?;
2528+
rctx.block.push(
2529+
crate::Statement::If {
2530+
condition,
2531+
accept,
2532+
reject,
2533+
},
2534+
span,
2535+
);
2536+
2537+
Ok(Typed::Reference(crate::Expression::LocalVariable(
2538+
result_var,
2539+
)))
2540+
} else {
2541+
let left_expr = ctx.get(left);
2542+
// Constant or override context in either function or module scope
2543+
let &crate::Expression::Literal(crate::Literal::Bool(left_val)) = left_expr else {
2544+
return Err(Box::new(Error::NotBool(span)));
2545+
};
2546+
2547+
if op == crate::BinaryOperator::LogicalAnd && !left_val
2548+
|| op == crate::BinaryOperator::LogicalOr && left_val
2549+
{
2550+
// Short-circuit behavior: don't evaluate the RHS. Ideally we
2551+
// would do _some_ validity checks of the RHS here, but that's
2552+
// tricky, because the RHS is allowed to have things that aren't
2553+
// legal in const contexts.
2554+
2555+
Ok(Typed::Plain(left_expr.clone()))
2556+
} else {
2557+
let right = self.expression_for_abstract(right, ctx)?;
2558+
ctx.grow_types(right)?;
2559+
2560+
Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
2561+
}
2562+
}
2563+
}
2564+
23782565
fn binary(
23792566
&mut self,
23802567
op: ir::BinaryOperator,
@@ -2383,57 +2570,74 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
23832570
span: Span,
23842571
ctx: &mut ExpressionContext<'source, '_, '_>,
23852572
) -> Result<'source, Typed<ir::Expression>> {
2386-
// Load both operands.
2387-
let mut left = self.expression_for_abstract(left, ctx)?;
2388-
let mut right = self.expression_for_abstract(right, ctx)?;
2389-
2390-
// Convert `scalar op vector` to `vector op vector` by introducing
2391-
// `Splat` expressions.
2392-
ctx.binary_op_splat(op, &mut left, &mut right)?;
2393-
2394-
// Apply automatic conversions.
2395-
match op {
2396-
ir::BinaryOperator::ShiftLeft | ir::BinaryOperator::ShiftRight => {
2397-
// Shift operators require the right operand to be `u32` or
2398-
// `vecN<u32>`. We can let the validator sort out vector length
2399-
// issues, but the right operand must be, or convert to, a u32 leaf
2400-
// scalar.
2401-
right =
2402-
ctx.try_automatic_conversion_for_leaf_scalar(right, ir::Scalar::U32, span)?;
2403-
2404-
// Additionally, we must concretize the left operand if the right operand
2405-
// is not a const-expression.
2406-
// See https://www.w3.org/TR/WGSL/#overload-resolution-section.
2407-
//
2408-
// 2. Eliminate any candidate where one of its subexpressions resolves to
2409-
// an abstract type after feasible automatic conversions, but another of
2410-
// the candidate’s subexpressions is not a const-expression.
2411-
//
2412-
// We only have to explicitly do so for shifts as their operands may be
2413-
// of different types - for other binary ops this is achieved by finding
2414-
// the conversion consensus for both operands.
2415-
if !ctx.is_const(right) {
2416-
left = ctx.concretize(left)?;
2417-
}
2573+
if op == ir::BinaryOperator::LogicalAnd || op == ir::BinaryOperator::LogicalOr {
2574+
let left = self.expression_for_abstract(left, ctx)?;
2575+
ctx.grow_types(left)?;
2576+
2577+
if !matches!(
2578+
resolve_inner!(ctx, left),
2579+
&ir::TypeInner::Scalar(ir::Scalar::BOOL)
2580+
) {
2581+
// Pass it through as-is, will fail validation
2582+
let right = self.expression_for_abstract(right, ctx)?;
2583+
ctx.grow_types(right)?;
2584+
Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
2585+
} else {
2586+
self.logical(op, left, right, span, ctx)
24182587
}
2588+
} else {
2589+
// Load both operands.
2590+
let mut left = self.expression_for_abstract(left, ctx)?;
2591+
let mut right = self.expression_for_abstract(right, ctx)?;
2592+
2593+
// Convert `scalar op vector` to `vector op vector` by introducing
2594+
// `Splat` expressions.
2595+
ctx.binary_op_splat(op, &mut left, &mut right)?;
2596+
2597+
// Apply automatic conversions.
2598+
match op {
2599+
ir::BinaryOperator::ShiftLeft | ir::BinaryOperator::ShiftRight => {
2600+
// Shift operators require the right operand to be `u32` or
2601+
// `vecN<u32>`. We can let the validator sort out vector length
2602+
// issues, but the right operand must be, or convert to, a u32 leaf
2603+
// scalar.
2604+
right =
2605+
ctx.try_automatic_conversion_for_leaf_scalar(right, ir::Scalar::U32, span)?;
2606+
2607+
// Additionally, we must concretize the left operand if the right operand
2608+
// is not a const-expression.
2609+
// See https://www.w3.org/TR/WGSL/#overload-resolution-section.
2610+
//
2611+
// 2. Eliminate any candidate where one of its subexpressions resolves to
2612+
// an abstract type after feasible automatic conversions, but another of
2613+
// the candidate’s subexpressions is not a const-expression.
2614+
//
2615+
// We only have to explicitly do so for shifts as their operands may be
2616+
// of different types - for other binary ops this is achieved by finding
2617+
// the conversion consensus for both operands.
2618+
if !ctx.is_const(right) {
2619+
left = ctx.concretize(left)?;
2620+
}
2621+
}
24192622

2420-
// All other operators follow the same pattern: reconcile the
2421-
// scalar leaf types. If there's no reconciliation possible,
2422-
// leave the expressions as they are: validation will report the
2423-
// problem.
2424-
_ => {
2425-
ctx.grow_types(left)?;
2426-
ctx.grow_types(right)?;
2427-
if let Ok(consensus_scalar) =
2428-
ctx.automatic_conversion_consensus([left, right].iter())
2429-
{
2430-
ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?;
2431-
ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?;
2623+
// All other operators follow the same pattern: reconcile the
2624+
// scalar leaf types. If there's no reconciliation possible,
2625+
// leave the expressions as they are: validation will report the
2626+
// problem.
2627+
_ => {
2628+
ctx.grow_types(left)?;
2629+
ctx.grow_types(right)?;
2630+
if let Ok(consensus_scalar) =
2631+
ctx.automatic_conversion_consensus([left, right].iter())
2632+
{
2633+
ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?;
2634+
ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?;
2635+
}
24322636
}
24332637
}
2434-
}
24352638

2436-
Ok(Typed::Plain(ir::Expression::Binary { op, left, right }))
2639+
Ok(Typed::Plain(ir::Expression::Binary { op, left, right }))
2640+
}
24372641
}
24382642

24392643
/// Generate Naga IR for call expressions and statements, and type

naga/tests/in/wgsl/operators.wgsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ fn bool_cast(x: vec3<f32>) -> vec3<f32> {
4040
return vec3<f32>(y);
4141
}
4242

43+
fn p() -> bool { return true; }
44+
fn q() -> bool { return false; }
45+
fn r() -> bool { return true; }
46+
fn s() -> bool { return false; }
47+
4348
fn logical() {
4449
let t = true;
4550
let f = false;
@@ -55,6 +60,7 @@ fn logical() {
5560
let bitwise_or1 = vec3(t) | vec3(f);
5661
let bitwise_and0 = t & f;
5762
let bitwise_and1 = vec4(t) & vec4(f);
63+
let short_circuit = (p() || q()) && (r() || s());
5864
}
5965

6066
fn arithmetic() {

0 commit comments

Comments
 (0)