Skip to content

Commit 2f255ed

Browse files
authored
[naga] Use const ctx instead of global ctx for type resolution (#6935)
Signed-off-by: sagudev <[email protected]>
1 parent e95f6d6 commit 2f255ed

File tree

8 files changed

+244
-190
lines changed

8 files changed

+244
-190
lines changed

naga/src/front/wgsl/lower/construction.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ impl<'source> Lowerer<'source, '_> {
579579
}
580580
ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
581581
ast::ConstructorType::Vector { size, ty, ty_span } => {
582-
let ty = self.resolve_ast_type(ty, &mut ctx.as_global())?;
582+
let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
583583
let scalar = match ctx.module.types[ty].inner {
584584
crate::TypeInner::Scalar(sc) => sc,
585585
_ => return Err(Error::UnknownScalarType(ty_span)),
@@ -596,7 +596,7 @@ impl<'source> Lowerer<'source, '_> {
596596
ty,
597597
ty_span,
598598
} => {
599-
let ty = self.resolve_ast_type(ty, &mut ctx.as_global())?;
599+
let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
600600
let scalar = match ctx.module.types[ty].inner {
601601
crate::TypeInner::Scalar(sc) => sc,
602602
_ => return Err(Error::UnknownScalarType(ty_span)),
@@ -613,8 +613,8 @@ impl<'source> Lowerer<'source, '_> {
613613
}
614614
ast::ConstructorType::PartialArray => Constructor::PartialArray,
615615
ast::ConstructorType::Array { base, size } => {
616-
let base = self.resolve_ast_type(base, &mut ctx.as_global())?;
617-
let size = self.array_size(size, &mut ctx.as_global())?;
616+
let base = self.resolve_ast_type(base, &mut ctx.as_const())?;
617+
let size = self.array_size(size, &mut ctx.as_const())?;
618618

619619
ctx.layouter.update(ctx.module.to_ctx()).unwrap();
620620
let stride = ctx.layouter[base].to_stride();

naga/src/front/wgsl/lower/mod.rs

Lines changed: 49 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
244244
}
245245
}
246246

247+
#[allow(dead_code)]
247248
fn as_global(&mut self) -> GlobalContext<'a, '_, '_> {
248249
GlobalContext {
249250
ast_expressions: self.ast_expressions,
@@ -468,29 +469,28 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
468469
.map_err(|e| Error::ConstantEvaluatorError(e.into(), span))
469470
}
470471

471-
fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> {
472+
fn const_eval_expr_to_u32(
473+
&self,
474+
handle: Handle<crate::Expression>,
475+
) -> Result<u32, crate::proc::U32EvalError> {
472476
match self.expr_type {
473477
ExpressionContextType::Runtime(ref ctx) => {
474478
if !ctx.local_expression_kind_tracker.is_const(handle) {
475-
return None;
479+
return Err(crate::proc::U32EvalError::NonConst);
476480
}
477481

478482
self.module
479483
.to_ctx()
480484
.eval_expr_to_u32_from(handle, &ctx.function.expressions)
481-
.ok()
482485
}
483486
ExpressionContextType::Constant(Some(ref ctx)) => {
484487
assert!(ctx.local_expression_kind_tracker.is_const(handle));
485488
self.module
486489
.to_ctx()
487490
.eval_expr_to_u32_from(handle, &ctx.function.expressions)
488-
.ok()
489491
}
490-
ExpressionContextType::Constant(None) => {
491-
self.module.to_ctx().eval_expr_to_u32(handle).ok()
492-
}
493-
ExpressionContextType::Override => None,
492+
ExpressionContextType::Constant(None) => self.module.to_ctx().eval_expr_to_u32(handle),
493+
ExpressionContextType::Override => Err(crate::proc::U32EvalError::NonConst),
494494
}
495495
}
496496

@@ -1069,7 +1069,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
10691069
}
10701070
ast::GlobalDeclKind::Var(ref v) => {
10711071
let explicit_ty =
1072-
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx))
1072+
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const()))
10731073
.transpose()?;
10741074

10751075
let (ty, initializer) =
@@ -1102,7 +1102,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
11021102
let mut ectx = ctx.as_const();
11031103

11041104
let explicit_ty =
1105-
c.ty.map(|ast| self.resolve_ast_type(ast, &mut ectx.as_global()))
1105+
c.ty.map(|ast| self.resolve_ast_type(ast, &mut ectx))
11061106
.transpose()?;
11071107

11081108
let (ty, init) =
@@ -1123,7 +1123,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
11231123
}
11241124
ast::GlobalDeclKind::Override(ref o) => {
11251125
let explicit_ty =
1126-
o.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx))
1126+
o.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const()))
11271127
.transpose()?;
11281128

11291129
let mut ectx = ctx.as_override();
@@ -1165,7 +1165,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
11651165
let ty = self.resolve_named_ast_type(
11661166
alias.ty,
11671167
Some(alias.name.name.to_string()),
1168-
&mut ctx,
1168+
&mut ctx.as_const(),
11691169
)?;
11701170
ctx.globals
11711171
.insert(alias.name.name, LoweredGlobalDecl::Type(ty));
@@ -1263,7 +1263,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
12631263
.iter()
12641264
.enumerate()
12651265
.map(|(i, arg)| -> Result<_, Error<'_>> {
1266-
let ty = self.resolve_ast_type(arg.ty, ctx)?;
1266+
let ty = self.resolve_ast_type(arg.ty, &mut ctx.as_const())?;
12671267
let expr = expressions
12681268
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
12691269
local_table.insert(arg.handle, Declared::Runtime(Typed::Plain(expr)));
@@ -1282,7 +1282,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
12821282
.result
12831283
.as_ref()
12841284
.map(|res| -> Result<_, Error<'_>> {
1285-
let ty = self.resolve_ast_type(res.ty, ctx)?;
1285+
let ty = self.resolve_ast_type(res.ty, &mut ctx.as_const())?;
12861286
Ok(crate::FunctionResult {
12871287
ty,
12881288
binding: self.binding(&res.binding, ty, ctx)?,
@@ -1440,9 +1440,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
14401440
// optimization.
14411441
ctx.local_expression_kind_tracker.force_non_const(value);
14421442

1443-
let explicit_ty =
1444-
l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
1445-
.transpose()?;
1443+
let explicit_ty = l
1444+
.ty
1445+
.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_const(block, &mut emitter)))
1446+
.transpose()?;
14461447

14471448
if let Some(ty) = explicit_ty {
14481449
let mut ctx = ctx.as_expression(block, &mut emitter);
@@ -1469,12 +1470,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
14691470
return Ok(());
14701471
}
14711472
ast::LocalDecl::Var(ref v) => {
1472-
let explicit_ty =
1473-
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global()))
1474-
.transpose()?;
1475-
14761473
let mut emitter = Emitter::default();
14771474
emitter.start(&ctx.function.expressions);
1475+
1476+
let explicit_ty =
1477+
v.ty.map(|ast| {
1478+
self.resolve_ast_type(ast, &mut ctx.as_const(block, &mut emitter))
1479+
})
1480+
.transpose()?;
1481+
14781482
let mut ectx = ctx.as_expression(block, &mut emitter);
14791483
let (ty, initializer) =
14801484
self.type_and_init(v.name, v.init, explicit_ty, &mut ectx)?;
@@ -1533,11 +1537,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
15331537
let ectx = &mut ctx.as_const(block, &mut emitter);
15341538

15351539
let explicit_ty =
1536-
c.ty.map(|ast| self.resolve_ast_type(ast, &mut ectx.as_global()))
1540+
c.ty.map(|ast| self.resolve_ast_type(ast, &mut ectx.as_const()))
15371541
.transpose()?;
15381542

1539-
let (_ty, init) =
1540-
self.type_and_init(c.name, Some(c.init), explicit_ty, ectx)?;
1543+
let (_ty, init) = self.type_and_init(
1544+
c.name,
1545+
Some(c.init),
1546+
explicit_ty,
1547+
&mut ectx.as_const(),
1548+
)?;
15411549
let init = init.expect("Local const must have init");
15421550

15431551
block.extend(emitter.finish(&ctx.function.expressions));
@@ -1992,7 +2000,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
19922000
}
19932001
}
19942002

1995-
lowered_base.map(|base| match ctx.const_access(index) {
2003+
lowered_base.map(|base| match ctx.const_eval_expr_to_u32(index).ok() {
19962004
Some(index) => crate::Expression::AccessIndex { base, index },
19972005
None => crate::Expression::Access { base, index },
19982006
})
@@ -2069,7 +2077,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
20692077
}
20702078
ast::Expression::Bitcast { expr, to, ty_span } => {
20712079
let expr = self.expression(expr, ctx)?;
2072-
let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?;
2080+
let to_resolved = self.resolve_ast_type(to, &mut ctx.as_const())?;
20732081

20742082
let element_scalar = match ctx.module.types[to_resolved].inner {
20752083
crate::TypeInner::Scalar(scalar) => scalar,
@@ -3051,7 +3059,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
30513059
let mut members = Vec::with_capacity(s.members.len());
30523060

30533061
for member in s.members.iter() {
3054-
let ty = self.resolve_ast_type(member.ty, ctx)?;
3062+
let ty = self.resolve_ast_type(member.ty, &mut ctx.as_const())?;
30553063

30563064
ctx.layouter.update(ctx.module.to_ctx()).unwrap();
30573065

@@ -3138,25 +3146,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31383146
fn array_size(
31393147
&mut self,
31403148
size: ast::ArraySize<'source>,
3141-
ctx: &mut GlobalContext<'source, '_, '_>,
3149+
ctx: &mut ExpressionContext<'source, '_, '_>,
31423150
) -> Result<crate::ArraySize, Error<'source>> {
31433151
Ok(match size {
31443152
ast::ArraySize::Constant(expr) => {
31453153
let span = ctx.ast_expressions.get_span(expr);
31463154
let const_expr = self.expression(expr, &mut ctx.as_const());
31473155
match const_expr {
31483156
Ok(value) => {
3149-
let len =
3150-
ctx.module.to_ctx().eval_expr_to_u32(value).map_err(
3151-
|err| match err {
3152-
crate::proc::U32EvalError::NonConst => {
3153-
Error::ExpectedConstExprConcreteIntegerScalar(span)
3154-
}
3155-
crate::proc::U32EvalError::Negative => {
3156-
Error::ExpectedPositiveArrayLength(span)
3157-
}
3158-
},
3159-
)?;
3157+
let len = ctx.const_eval_expr_to_u32(value).map_err(|err| match err {
3158+
crate::proc::U32EvalError::NonConst => {
3159+
Error::ExpectedConstExprConcreteIntegerScalar(span)
3160+
}
3161+
crate::proc::U32EvalError::Negative => {
3162+
Error::ExpectedPositiveArrayLength(span)
3163+
}
3164+
})?;
31603165
let size =
31613166
NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?;
31623167
crate::ArraySize::Constant(size)
@@ -3167,7 +3172,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31673172
crate::proc::ConstantEvaluatorError::OverrideExpr => {
31683173
crate::ArraySize::Pending(self.array_size_override(
31693174
expr,
3170-
&mut ctx.as_override(),
3175+
&mut ctx.as_global().as_override(),
31713176
span,
31723177
)?)
31733178
}
@@ -3219,7 +3224,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
32193224
&mut self,
32203225
handle: Handle<ast::Type<'source>>,
32213226
name: Option<String>,
3222-
ctx: &mut GlobalContext<'source, '_, '_>,
3227+
ctx: &mut ExpressionContext<'source, '_, '_>,
32233228
) -> Result<Handle<crate::Type>, Error<'source>> {
32243229
let inner = match ctx.types[handle] {
32253230
ast::Type::Scalar(scalar) => scalar.to_inner_scalar(),
@@ -3257,7 +3262,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
32573262
crate::TypeInner::Pointer { base, space }
32583263
}
32593264
ast::Type::Array { base, size } => {
3260-
let base = self.resolve_ast_type(base, ctx)?;
3265+
let base = self.resolve_ast_type(base, &mut ctx.as_const())?;
32613266
let size = self.array_size(size, ctx)?;
32623267

32633268
ctx.layouter.update(ctx.module.to_ctx()).unwrap();
@@ -3297,14 +3302,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
32973302
}
32983303
};
32993304

3300-
Ok(ctx.ensure_type_exists(name, inner))
3305+
Ok(ctx.as_global().ensure_type_exists(name, inner))
33013306
}
33023307

33033308
/// Return a Naga `Handle<Type>` representing the front-end type `handle`.
33043309
fn resolve_ast_type(
33053310
&mut self,
33063311
handle: Handle<ast::Type<'source>>,
3307-
ctx: &mut GlobalContext<'source, '_, '_>,
3312+
ctx: &mut ExpressionContext<'source, '_, '_>,
33083313
) -> Result<Handle<crate::Type>, Error<'source>> {
33093314
self.resolve_named_ast_type(handle, None, ctx)
33103315
}

naga/tests/in/const-exprs.wgsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ fn main() {
1212
splat_of_constant();
1313
compose_of_constant();
1414
compose_of_splat();
15+
test_local_const();
1516
}
1617

1718
// Swizzle the value of nested Compose expressions.
@@ -109,3 +110,8 @@ fn relational() {
109110
var vec_all_false = all(vec4(vec3(vec2<bool>(), TRUE), false));
110111
var vec_all_true = all(vec4(true));
111112
}
113+
114+
fn test_local_const() {
115+
const local_const = 2;
116+
var arr: array<f32, local_const>;
117+
}

naga/tests/out/glsl/const-exprs.main.Compute.glsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ void compose_of_splat() {
7171
return;
7272
}
7373

74+
void test_local_const() {
75+
float arr[2] = float[2](0.0, 0.0);
76+
return;
77+
}
78+
7479
uint map_texture_kind(int texture_kind) {
7580
switch(texture_kind) {
7681
case 0: {
@@ -115,6 +120,7 @@ void main() {
115120
splat_of_constant();
116121
compose_of_constant();
117122
compose_of_splat();
123+
test_local_const();
118124
return;
119125
}
120126

naga/tests/out/hlsl/const-exprs.hlsl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ void compose_of_splat()
7777
return;
7878
}
7979

80+
void test_local_const()
81+
{
82+
float arr[2] = (float[2])0;
83+
84+
return;
85+
}
86+
8087
uint map_texture_kind(int texture_kind)
8188
{
8289
switch(texture_kind) {
@@ -128,5 +135,6 @@ void main()
128135
splat_of_constant();
129136
compose_of_constant();
130137
compose_of_splat();
138+
test_local_const();
131139
return;
132140
}

naga/tests/out/msl/const-exprs.msl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
using metal::uint;
66

7+
struct type_6 {
8+
float inner[2];
9+
};
710
constant uint TWO = 2u;
811
constant int THREE = 3;
912
constant bool TRUE = true;
@@ -76,6 +79,12 @@ void compose_of_splat(
7679
return;
7780
}
7881

82+
void test_local_const(
83+
) {
84+
type_6 arr = {};
85+
return;
86+
}
87+
7988
uint map_texture_kind(
8089
int texture_kind
8190
) {
@@ -125,5 +134,6 @@ kernel void main_(
125134
splat_of_constant();
126135
compose_of_constant();
127136
compose_of_splat();
137+
test_local_const();
128138
return;
129139
}

0 commit comments

Comments
 (0)