Scoping of temporaries within conditionals needs extra safeguards. #346
Description
Problem
When a temporary is assigned inside a conditional, scalarization can produce unexpected results in cases where the front-end should reject the code [1].
Consider the following stencil:
def test_stencil(
    out_field: Field[data_type],
    helper: Field[gtscript.IJ, data_type],
):
    with computation(FORWARD), interval(...):
        if helper < 1:
            tmp = 10
        out_field[0, 0, 0] = tmp
        helper = 3
This code should raise some sort of UnboundLocalError, since `tmp` is only set at the very first level and none of the following levels should be able to read it. However, the scalarization pass sees that `tmp` is only accessed pointwise and turns it into a scalar. As a result, the first level writes to the scalar `tmp` and all subsequent levels read the value left behind.
Depending on the backend, we get the following results:
| Backend | Result |
|---|---|
| Debug | output = 10 |
| gt:cpu | output = 10 |
| dace:cpu | output = 10 |
| expected | UnboundLocalError |
Switching the condition to `if helper > 1:` has the following effect:
| Backend | Result |
|---|---|
| Debug | UnboundLocalError |
| gt:cpu | output = 10 |
| dace:cpu | output = 10 |
| expected | UnboundLocalError |
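The effect of scalarization on the first stencil (`if helper < 1:`) can be mimicked in plain Python. This is only a rough sketch under stated assumptions: it does not use GT4Py, the function names `run_scalarized` and `run_strict` are hypothetical, `helper` is assumed to start at 0, and a single column with `nk` levels stands in for the full field.

```python
def run_scalarized(helper, nk):
    """Mimic a scalarized temporary: one Python local shared by all K levels."""
    out = []
    for _k in range(nk):  # FORWARD computation: levels are visited in order
        if helper < 1:
            tmp = 10  # only executed while helper < 1, i.e. at the first level
        out.append(tmp)  # later levels read the value left over from before
        helper = 3
    return out


def run_strict(helper, nk):
    """Mimic the expected semantics: tmp is scoped to a single level."""
    out = []
    for k in range(nk):
        level_scope = {}  # fresh scope per level, so nothing lingers
        if helper < 1:
            level_scope["tmp"] = 10
        if "tmp" not in level_scope:
            raise UnboundLocalError(f"tmp read before assignment at level k={k}")
        out.append(level_scope["tmp"])
        helper = 3
    return out


print(run_scalarized(0, 3))  # every level sees the value written at level 0
try:
    run_strict(0, 3)
except UnboundLocalError as err:
    print("strict semantics:", err)
```

With the shared local, every level reports 10, which matches the first table. With a fresh scope per level, the read at the second level fails, which is the behavior this issue asks for.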
Solutions
- Work on the scalarization pass so that it does not scalarize temporaries used in any conditional.
Note that this still generates code that executes without errors and produces different results between backends (the "debug" backend reads garbage data, the C++ backends read zeros). It does not make it easier for the user to understand that the code is wrong, but at least it does not keep lingering values around.
- Add guardrails in the front-end to check for usage like this and prohibit it
The difficult part here is that temporaries inside conditionals are an important feature that needs to be allowed.
if namelist_flag:
    tmp = something
# other code
if namelist_flag:
    out = function(tmp)
This is a common pattern that should still work.
In that case, however, we need to check that all conditionals are identical and that nothing they evaluate is modified in between, which is not trivial.
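To make the trade-off concrete, here is a minimal sketch of a conservative "assigned on every path" check written with Python's standard `ast` module. It is not the GT4Py front-end and does not use any GT4Py API; the function name `maybe_unbound_reads` and the `known` argument (names that are always defined, such as stencil arguments) are assumptions made for illustration. It only understands plain statements and `if`/`else`.

```python
import ast


def maybe_unbound_reads(source, known):
    """Return (name, lineno) for reads of names not assigned on every path.

    `known` holds names that are always defined (arguments, functions).
    """
    findings = []

    def check_reads(node, defined):
        # Flag every name that is read but not definitely assigned yet.
        for sub in ast.walk(node):
            if isinstance(sub, ast.Name) and isinstance(sub.ctx, ast.Load):
                if sub.id not in defined:
                    findings.append((sub.id, sub.lineno))

    def visit_block(stmts, defined):
        for stmt in stmts:
            if isinstance(stmt, ast.If):
                check_reads(stmt.test, defined)
                after_then = visit_block(stmt.body, set(defined))
                after_else = visit_block(stmt.orelse, set(defined))
                # Only names assigned in both branches survive the conditional.
                defined = after_then & after_else
            else:
                check_reads(stmt, defined)
                for sub in ast.walk(stmt):
                    if isinstance(sub, ast.Name) and isinstance(sub.ctx, ast.Store):
                        defined.add(sub.id)
        return defined

    visit_block(ast.parse(source).body, set(known))
    return findings


problem_body = """
if helper < 1:
    tmp = 10
out_field[0, 0, 0] = tmp
helper = 3
"""
print(maybe_unbound_reads(problem_body, {"helper", "out_field"}))

flag_body = """
if namelist_flag:
    tmp = something
if namelist_flag:
    out = function(tmp)
"""
print(maybe_unbound_reads(flag_body, {"namelist_flag", "something", "function"}))
```

The check flags the read of `tmp` in the problematic stencil body, but it also flags the `namelist_flag` pattern, because it does not know that both conditionals evaluate to the same value. Accepting that pattern would require the extra analysis described above.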
Notes
[1]: We currently get a "gt4py/cartesian/gtc/passes/gtir_pipeline.py:48: UserWarning: tmp may be uninitialized." warning, but it is questionable whether this is strong enough.
Reproducer
Full reproducer code to run standalone here: tempscope.py