Skip to content

Commit 0679d07

Browse files
[FXC-2301] Fixed the translator not recursively finding solver variable names (#1365) (#1368)
Co-authored-by: Ben <106089368+benflexcompute@users.noreply.github.com>
1 parent 29b84c9 commit 0679d07

File tree

4 files changed

+251
-12
lines changed

4 files changed

+251
-12
lines changed

flow360/component/simulation/outputs/outputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _validate_improper_surface_field_usage(cls, value: UniqueItemList):
134134
):
135135
continue
136136
surface_solver_variable_names = output_item.value.solver_variable_names(
137-
variable_type="Surface"
137+
recursive=True, variable_type="Surface"
138138
)
139139
if len(surface_solver_variable_names) > 0:
140140
raise ValueError(

flow360/component/simulation/translator/solver_translator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ def user_variable_to_udf(
577577

578578
def _prepare_prepending_code(expression: Expression):
579579
prepending_code = []
580-
for name in sorted(expression.solver_variable_names()):
580+
for name in sorted(expression.solver_variable_names(recursive=True)):
581581
if not udf_prepending_code.get(name):
582582
continue
583583
if name == "solution.temperature" and input_params.has_solid():
@@ -601,7 +601,6 @@ def _prepare_prepending_code(expression: Expression):
601601

602602
expression_length = expression.length
603603
prepending_code = _prepare_prepending_code(expression=expression)
604-
605604
if expression_length == 0: # Scalar output requested
606605
expression = expression.evaluate(raise_on_non_evaluable=False, force_evaluate=False)
607606
if offset != 0:

flow360/component/simulation/user_code/core/types.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -847,8 +847,9 @@ def ensure_dependent_feature_enabled(self) -> str:
847847
validation_info = get_validation_info()
848848
if validation_info is None or self.expression not in validation_info.referenced_expressions:
849849
return self
850-
851-
for solver_variable_name in self.solver_variable_names():
850+
# Setting recursive to False to avoid recursive error message.
851+
# All user variables will be checked anyways.
852+
for solver_variable_name in self.solver_variable_names(recursive=False):
852853
if solver_variable_name in _feature_requirement_map:
853854
if not _feature_requirement_map[solver_variable_name][0](validation_info):
854855
raise ValueError(
@@ -902,15 +903,65 @@ def user_variable_names(self):
902903
return names
903904

904905
def solver_variable_names(
905-
self, variable_type: Literal["Volume", "Surface", "Scalar", "All"] = "All"
906+
self,
907+
recursive: bool,
908+
variable_type: Literal["Volume", "Surface", "Scalar", "All"] = "All",
906909
):
907-
"""Get list of solver variable names used in expression."""
908-
expr = expr_to_model(self.expression, default_context)
909-
names = expr.used_names()
910-
names = [name for name in names if name in _solver_variables]
910+
"""Get list of solver variable names used in expression, recursively checking user variables.
911+
912+
Params:
913+
-------
914+
- variable_type: The type of variable to get the names of.
915+
- recursive: Whether to recursively check user variables for solver variables.
916+
"""
917+
918+
def _get_solver_variable_names_recursive(
919+
expression: Expression, visited: set[str], recursive: bool
920+
) -> set[str]:
921+
"""Recursively get solver variable names from expression and its user variables."""
922+
solver_names = set()
923+
924+
# Prevent infinite recursion by tracking visited expressions
925+
expr_str = str(expression)
926+
if expr_str in visited:
927+
return solver_names
928+
visited.add(expr_str)
929+
930+
# Get solver variables directly from this expression
931+
expr = expr_to_model(expression.expression, default_context)
932+
names = expr.used_names()
933+
direct_solver_names = [name for name in names if name in _solver_variables]
934+
solver_names.update(direct_solver_names)
935+
936+
if not recursive:
937+
return solver_names
938+
939+
# Get user variables from this expression and recursively check their values
940+
user_vars = expression.user_variables()
941+
for user_var in user_vars:
942+
try:
943+
if isinstance(user_var.value, Expression):
944+
# Recursively check the user variable's expression
945+
recursive_solver_names = _get_solver_variable_names_recursive(
946+
user_var.value, visited, recursive
947+
)
948+
solver_names.update(recursive_solver_names)
949+
except (ValueError, AttributeError):
950+
# Handle cases where user variable might not be properly defined
951+
pass
952+
953+
return solver_names
954+
955+
# Start the recursive search
956+
all_solver_names = _get_solver_variable_names_recursive(self, set(), recursive)
957+
958+
# Filter by variable type if specified
911959
if variable_type != "All":
912-
names = [name for name in names if _solver_variables[name] == variable_type]
913-
return names
960+
all_solver_names = {
961+
name for name in all_solver_names if _solver_variables[name] == variable_type
962+
}
963+
964+
return list(all_solver_names)
914965

915966
def to_solver_code(self, params):
916967
"""Convert to solver readable code."""

tests/simulation/test_expressions.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,39 @@ def test_udf_generator():
756756
== "double ___velocity[3];___velocity[0] = primitiveVars[1] * velocityScale;___velocity[1] = primitiveVars[2] * velocityScale;___velocity[2] = primitiveVars[3] * velocityScale;pos_vel[0] = (+___velocity[0] * 5.0); pos_vel[1] = (+___velocity[1] * 5.0); pos_vel[2] = (+___velocity[2] * 5.0);"
757757
)
758758

759+
density_kg_per_m3 = UserVariable(name="density_kg_per_m3", value=solution.density).in_units(
760+
new_unit="kg /m**3"
761+
)
762+
velocity_metric = UserVariable(name="velocity_metric", value=solution.velocity).in_units(
763+
new_unit="m/s"
764+
)
765+
mass_flow_rate_kg_per_s_per_m2 = UserVariable(
766+
name="mass_flow_rate_kg_per_s",
767+
value=math.dot(velocity_metric, solution.node_unit_normal) * density_kg_per_m3,
768+
).in_units(new_unit="kg/s/m**2")
769+
770+
assert user_variable_to_udf(mass_flow_rate_kg_per_s_per_m2, input_params=params).expression == (
771+
"double ___density;"
772+
"___density = usingLiquidAsMaterial ? 1.0 : primitiveVars[0];"
773+
"double ___node_unit_normal[3];"
774+
"double ___normalMag = magnitude(nodeNormals);"
775+
"for (int i = 0; i < 3; i++)"
776+
"{"
777+
"___node_unit_normal[i] = nodeNormals[i] / ___normalMag;"
778+
"}"
779+
"double ___velocity[3];"
780+
"___velocity[0] = primitiveVars[1] * velocityScale;"
781+
"___velocity[1] = primitiveVars[2] * velocityScale;"
782+
"___velocity[2] = primitiveVars[3] * velocityScale;"
783+
"mass_flow_rate_kg_per_s = ("
784+
"((("
785+
"(___velocity[0] * ___node_unit_normal[0]) + "
786+
"(___velocity[1] * ___node_unit_normal[1])"
787+
") + "
788+
"(___velocity[2] * ___node_unit_normal[2])"
789+
") * ___density) * 5000.0);"
790+
)
791+
759792

760793
def test_project_variables_serialization():
761794
ccc = UserVariable(name="ccc", value=12 * u.m / u.s, description="ccc description")
@@ -1760,3 +1793,159 @@ def test_correct_expression_error_location():
17601793
"operator for unyt_arrays with units 'dimensionless' (dimensions '1') and 'm' (dimensions '(length)') is not well defined."
17611794
in errors[0]["msg"]
17621795
)
1796+
1797+
1798+
def test_solver_variable_names_recursive():
1799+
"""Test the recursive solver_variable_names method with proper physical dimensions."""
1800+
1801+
# Test 1: Direct solver variable usage
1802+
expr1 = Expression(expression="solution.density + solution.pressure")
1803+
solver_vars = expr1.solver_variable_names(recursive=True)
1804+
assert set(solver_vars) == {"solution.density", "solution.pressure"}
1805+
1806+
# Test 2: No solver variables - pure mathematical expression
1807+
expr2 = Expression(expression="1.0 + 2.0 * 3.0")
1808+
solver_vars = expr2.solver_variable_names(recursive=True)
1809+
assert solver_vars == []
1810+
1811+
# Test 3: Simple user variable with solver variable (dimensionally consistent)
1812+
user_var1 = UserVariable(name="my_density", value=solution.density)
1813+
expr3 = Expression(expression="my_density * 2.0")
1814+
solver_vars = expr3.solver_variable_names(recursive=True)
1815+
assert solver_vars == ["solution.density"]
1816+
1817+
# Test 4: Nested user variables - velocity component access
1818+
user_var2 = UserVariable(name="vel_x_comp", value=solution.velocity[0])
1819+
user_var3 = UserVariable(name="scaled_vel", value=user_var2 * 2.0)
1820+
expr4 = Expression(expression="scaled_vel")
1821+
solver_vars = expr4.solver_variable_names(recursive=True)
1822+
assert solver_vars == ["solution.velocity"]
1823+
1824+
# Test 5: Multiple levels of nesting with dimensional consistency
1825+
user_var4 = UserVariable(name="rho_squared", value=user_var1 * user_var1)
1826+
user_var5 = UserVariable(name="momentum_like", value=user_var4 * user_var3)
1827+
expr5 = Expression(expression="momentum_like")
1828+
solver_vars = expr5.solver_variable_names(recursive=True)
1829+
assert set(solver_vars) == {"solution.density", "solution.velocity"}
1830+
1831+
# Test 6: Mixed direct and indirect solver variables with proper dimensions
1832+
expr6 = Expression(expression="momentum_like + solution.pressure * solution.density")
1833+
solver_vars = expr6.solver_variable_names(recursive=True)
1834+
assert set(solver_vars) == {"solution.density", "solution.velocity", "solution.pressure"}
1835+
1836+
# Test 7: User variable with dimensionless value
1837+
user_var6 = UserVariable(name="mach_number", value=0.3)
1838+
expr7 = Expression(expression="mach_number * solution.velocity[0]")
1839+
solver_vars = expr7.solver_variable_names(recursive=True)
1840+
assert solver_vars == ["solution.velocity"]
1841+
1842+
# Test 8: Filter by variable type - Volume variables only
1843+
expr8 = Expression(expression="solution.density + solution.velocity[0] + control.MachRef")
1844+
volume_vars = expr8.solver_variable_names(variable_type="Volume", recursive=True)
1845+
assert set(volume_vars) == {"solution.density", "solution.velocity"}
1846+
1847+
# Test 9: Filter by variable type - Scalar variables only
1848+
scalar_vars = expr8.solver_variable_names(variable_type="Scalar", recursive=True)
1849+
assert scalar_vars == ["control.MachRef"]
1850+
1851+
# Test 10: Complex flow physics expression with proper dimensions
1852+
user_var7 = UserVariable(
1853+
name="dyn_press", value=0.5 * solution.density * solution.velocity[0] * solution.velocity[0]
1854+
)
1855+
user_var8 = UserVariable(name="tot_press", value=solution.pressure + user_var7)
1856+
expr9 = Expression(expression="tot_press")
1857+
solver_vars = expr9.solver_variable_names(recursive=True)
1858+
assert set(solver_vars) == {"solution.density", "solution.velocity", "solution.pressure"}
1859+
1860+
# Test 11: Temperature-based expressions (for compressible flow)
1861+
user_var9 = UserVariable(
1862+
name="temp_ratio", value=solution.temperature / 300.0
1863+
) # Reference temperature
1864+
user_var10 = UserVariable(name="scaled_density", value=solution.density * user_var9)
1865+
expr10 = Expression(expression="scaled_density")
1866+
solver_vars = expr10.solver_variable_names(recursive=True)
1867+
assert set(solver_vars) == {"solution.temperature", "solution.density"}
1868+
1869+
# Test 12: Vector operations with proper indexing
1870+
user_var11 = UserVariable(
1871+
name="vel_mag_sq",
1872+
value=solution.velocity[0] * solution.velocity[0]
1873+
+ solution.velocity[1] * solution.velocity[1]
1874+
+ solution.velocity[2] * solution.velocity[2],
1875+
)
1876+
expr11 = Expression(expression="vel_mag_sq")
1877+
solver_vars = expr11.solver_variable_names(recursive=True)
1878+
assert solver_vars == ["solution.velocity"]
1879+
1880+
# Test 13: Deep nesting with multiple physics variables
1881+
user_var12 = UserVariable(name="ke_calc", value=0.5 * solution.density * user_var11)
1882+
user_var13 = UserVariable(name="te_calc", value=user_var12 + solution.pressure / (1.4 - 1.0))
1883+
user_var14 = UserVariable(name="epv_calc", value=user_var13 / solution.temperature)
1884+
expr12 = Expression(expression="epv_calc")
1885+
solver_vars = expr12.solver_variable_names(recursive=True)
1886+
expected_vars = {
1887+
"solution.density",
1888+
"solution.velocity",
1889+
"solution.pressure",
1890+
"solution.temperature",
1891+
}
1892+
assert set(solver_vars) == expected_vars
1893+
1894+
# Test 14: Mathematical functions with solver variables (using dimensionless ratios)
1895+
user_var15 = UserVariable(
1896+
name="rho_ratio", value=solution.density / solution.density
1897+
) # Dimensionless ratio
1898+
user_var16 = UserVariable(
1899+
name="complex_func", value=Expression(expression="math.exp(rho_ratio)")
1900+
)
1901+
expr13 = Expression(expression="complex_func")
1902+
solver_vars = expr13.solver_variable_names(recursive=True)
1903+
assert solver_vars == ["solution.density"]
1904+
1905+
# Test 15: Control variables in time-dependent expressions
1906+
user_var17 = UserVariable(
1907+
name="time_scaled_vel", value=solution.velocity[0] * control.timeStepSize
1908+
)
1909+
expr14 = Expression(expression="time_scaled_vel") # Just use the time-scaled velocity
1910+
solver_vars = expr14.solver_variable_names(recursive=True)
1911+
assert set(solver_vars) == {"solution.velocity", "control.timeStepSize"}
1912+
1913+
# Test 16: Edge case - circular reference prevention
1914+
user_var18 = UserVariable(name="base_var", value=solution.pressure)
1915+
user_var19 = UserVariable(name="derived_var", value=user_var18 * 2.0)
1916+
user_var20 = UserVariable(name="twice_derived", value=user_var19 + user_var18)
1917+
expr15 = Expression(expression="twice_derived")
1918+
solver_vars = expr15.solver_variable_names(recursive=True)
1919+
assert solver_vars == ["solution.pressure"]
1920+
1921+
# Test 17: Mixed dimensioned and dimensionless variables
1922+
user_var21 = UserVariable(name="re_num", value=1e6) # Dimensionless
1923+
user_var22 = UserVariable(
1924+
name="char_vel", value=user_var21 * solution.mut / (solution.density * 1.0)
1925+
) # Length = 1.0 m
1926+
expr16 = Expression(expression="char_vel")
1927+
solver_vars = expr16.solver_variable_names(recursive=True)
1928+
assert set(solver_vars) == {"solution.mut", "solution.density"}
1929+
1930+
# Test 18: Surface variables if available
1931+
try:
1932+
# Only test if surface variables exist
1933+
expr17 = Expression(expression="solution.Cp + solution.density")
1934+
solver_vars = expr17.solver_variable_names(variable_type="Surface", recursive=True)
1935+
# Check if surface variables are found
1936+
surface_vars = [var for var in solver_vars if "Cp" in var]
1937+
if surface_vars:
1938+
assert "solution.Cp" in solver_vars
1939+
except (AttributeError, ValueError):
1940+
# Skip if surface variables not available in test environment
1941+
pass
1942+
1943+
# Test 19: All variable types combined
1944+
expr18 = Expression(expression="solution.density + solution.velocity[0] + control.MachRef")
1945+
all_vars = expr18.solver_variable_names(variable_type="All", recursive=True)
1946+
assert set(all_vars) == {"solution.density", "solution.velocity", "control.MachRef"}
1947+
1948+
# Test 20: Simple expression with just a constant
1949+
expr19 = Expression(expression="42.0")
1950+
solver_vars = expr19.solver_variable_names(recursive=True)
1951+
assert solver_vars == []

0 commit comments

Comments
 (0)