Skip to content

Commit ec09bd1

Browse files
benflexcomputeflow360-auto-hotfix-bot
authored andcommitted
[FXC-2301] Fixed the translator not recursively finding solver variable names (#1365)
1 parent 16fedf5 commit ec09bd1

File tree

4 files changed

+251
-12
lines changed

4 files changed

+251
-12
lines changed

flow360/component/simulation/outputs/outputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _validate_improper_surface_field_usage(cls, value: UniqueItemList):
134134
):
135135
continue
136136
surface_solver_variable_names = output_item.value.solver_variable_names(
137-
variable_type="Surface"
137+
recursive=True, variable_type="Surface"
138138
)
139139
if len(surface_solver_variable_names) > 0:
140140
raise ValueError(

flow360/component/simulation/translator/solver_translator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def user_variable_to_udf(
609609

610610
def _prepare_prepending_code(expression: Expression):
611611
prepending_code = []
612-
for name in sorted(expression.solver_variable_names()):
612+
for name in sorted(expression.solver_variable_names(recursive=True)):
613613
if not udf_prepending_code.get(name):
614614
continue
615615
if name == "solution.temperature" and input_params.has_solid():
@@ -633,7 +633,6 @@ def _prepare_prepending_code(expression: Expression):
633633

634634
expression_length = expression.length
635635
prepending_code = _prepare_prepending_code(expression=expression)
636-
637636
if expression_length == 0: # Scalar output requested
638637
expression = expression.evaluate(raise_on_non_evaluable=False, force_evaluate=False)
639638
if offset != 0:

flow360/component/simulation/user_code/core/types.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -845,8 +845,9 @@ def ensure_dependent_feature_enabled(self) -> str:
845845
validation_info = get_validation_info()
846846
if validation_info is None or self.expression not in validation_info.referenced_expressions:
847847
return self
848-
849-
for solver_variable_name in self.solver_variable_names():
848+
# Setting recursive to False to avoid recursive error message.
849+
# All user variables will be checked anyways.
850+
for solver_variable_name in self.solver_variable_names(recursive=False):
850851
if solver_variable_name in _feature_requirement_map:
851852
if not _feature_requirement_map[solver_variable_name][0](validation_info):
852853
raise ValueError(
@@ -900,15 +901,65 @@ def user_variable_names(self):
900901
return names
901902

902903
def solver_variable_names(
903-
self, variable_type: Literal["Volume", "Surface", "Scalar", "All"] = "All"
904+
self,
905+
recursive: bool,
906+
variable_type: Literal["Volume", "Surface", "Scalar", "All"] = "All",
904907
):
905-
"""Get list of solver variable names used in expression."""
906-
expr = expr_to_model(self.expression, default_context)
907-
names = expr.used_names()
908-
names = [name for name in names if name in _solver_variables]
908+
"""Get list of solver variable names used in expression, recursively checking user variables.
909+
910+
Params:
911+
-------
912+
- variable_type: The type of variable to get the names of.
913+
- recursive: Whether to recursively check user variables for solver variables.
914+
"""
915+
916+
def _get_solver_variable_names_recursive(
917+
expression: Expression, visited: set[str], recursive: bool
918+
) -> set[str]:
919+
"""Recursively get solver variable names from expression and its user variables."""
920+
solver_names = set()
921+
922+
# Prevent infinite recursion by tracking visited expressions
923+
expr_str = str(expression)
924+
if expr_str in visited:
925+
return solver_names
926+
visited.add(expr_str)
927+
928+
# Get solver variables directly from this expression
929+
expr = expr_to_model(expression.expression, default_context)
930+
names = expr.used_names()
931+
direct_solver_names = [name for name in names if name in _solver_variables]
932+
solver_names.update(direct_solver_names)
933+
934+
if not recursive:
935+
return solver_names
936+
937+
# Get user variables from this expression and recursively check their values
938+
user_vars = expression.user_variables()
939+
for user_var in user_vars:
940+
try:
941+
if isinstance(user_var.value, Expression):
942+
# Recursively check the user variable's expression
943+
recursive_solver_names = _get_solver_variable_names_recursive(
944+
user_var.value, visited, recursive
945+
)
946+
solver_names.update(recursive_solver_names)
947+
except (ValueError, AttributeError):
948+
# Handle cases where user variable might not be properly defined
949+
pass
950+
951+
return solver_names
952+
953+
# Start the recursive search
954+
all_solver_names = _get_solver_variable_names_recursive(self, set(), recursive)
955+
956+
# Filter by variable type if specified
909957
if variable_type != "All":
910-
names = [name for name in names if _solver_variables[name] == variable_type]
911-
return names
958+
all_solver_names = {
959+
name for name in all_solver_names if _solver_variables[name] == variable_type
960+
}
961+
962+
return list(all_solver_names)
912963

913964
def to_solver_code(self, params):
914965
"""Convert to solver readable code."""

tests/simulation/test_expressions.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,39 @@ def test_udf_generator():
756756
== "double ___velocity[3];___velocity[0] = primitiveVars[1] * velocityScale;___velocity[1] = primitiveVars[2] * velocityScale;___velocity[2] = primitiveVars[3] * velocityScale;pos_vel[0] = (+___velocity[0] * 5.0); pos_vel[1] = (+___velocity[1] * 5.0); pos_vel[2] = (+___velocity[2] * 5.0);"
757757
)
758758

759+
density_kg_per_m3 = UserVariable(name="density_kg_per_m3", value=solution.density).in_units(
760+
new_unit="kg /m**3"
761+
)
762+
velocity_metric = UserVariable(name="velocity_metric", value=solution.velocity).in_units(
763+
new_unit="m/s"
764+
)
765+
mass_flow_rate_kg_per_s_per_m2 = UserVariable(
766+
name="mass_flow_rate_kg_per_s",
767+
value=math.dot(velocity_metric, solution.node_unit_normal) * density_kg_per_m3,
768+
).in_units(new_unit="kg/s/m**2")
769+
770+
assert user_variable_to_udf(mass_flow_rate_kg_per_s_per_m2, input_params=params).expression == (
771+
"double ___density;"
772+
"___density = usingLiquidAsMaterial ? 1.0 : primitiveVars[0];"
773+
"double ___node_unit_normal[3];"
774+
"double ___normalMag = magnitude(nodeNormals);"
775+
"for (int i = 0; i < 3; i++)"
776+
"{"
777+
"___node_unit_normal[i] = nodeNormals[i] / ___normalMag;"
778+
"}"
779+
"double ___velocity[3];"
780+
"___velocity[0] = primitiveVars[1] * velocityScale;"
781+
"___velocity[1] = primitiveVars[2] * velocityScale;"
782+
"___velocity[2] = primitiveVars[3] * velocityScale;"
783+
"mass_flow_rate_kg_per_s = ("
784+
"((("
785+
"(___velocity[0] * ___node_unit_normal[0]) + "
786+
"(___velocity[1] * ___node_unit_normal[1])"
787+
") + "
788+
"(___velocity[2] * ___node_unit_normal[2])"
789+
") * ___density) * 5000.0);"
790+
)
791+
759792

760793
def test_project_variables_serialization():
761794
ccc = UserVariable(name="ccc", value=12 * u.m / u.s, description="ccc description")
@@ -1751,3 +1784,159 @@ def test_correct_expression_error_location():
17511784
"operator for unyt_arrays with units 'dimensionless' (dimensions '1') and 'm' (dimensions '(length)') is not well defined."
17521785
in errors[0]["msg"]
17531786
)
1787+
1788+
1789+
def test_solver_variable_names_recursive():
1790+
"""Test the recursive solver_variable_names method with proper physical dimensions."""
1791+
1792+
# Test 1: Direct solver variable usage
1793+
expr1 = Expression(expression="solution.density + solution.pressure")
1794+
solver_vars = expr1.solver_variable_names(recursive=True)
1795+
assert set(solver_vars) == {"solution.density", "solution.pressure"}
1796+
1797+
# Test 2: No solver variables - pure mathematical expression
1798+
expr2 = Expression(expression="1.0 + 2.0 * 3.0")
1799+
solver_vars = expr2.solver_variable_names(recursive=True)
1800+
assert solver_vars == []
1801+
1802+
# Test 3: Simple user variable with solver variable (dimensionally consistent)
1803+
user_var1 = UserVariable(name="my_density", value=solution.density)
1804+
expr3 = Expression(expression="my_density * 2.0")
1805+
solver_vars = expr3.solver_variable_names(recursive=True)
1806+
assert solver_vars == ["solution.density"]
1807+
1808+
# Test 4: Nested user variables - velocity component access
1809+
user_var2 = UserVariable(name="vel_x_comp", value=solution.velocity[0])
1810+
user_var3 = UserVariable(name="scaled_vel", value=user_var2 * 2.0)
1811+
expr4 = Expression(expression="scaled_vel")
1812+
solver_vars = expr4.solver_variable_names(recursive=True)
1813+
assert solver_vars == ["solution.velocity"]
1814+
1815+
# Test 5: Multiple levels of nesting with dimensional consistency
1816+
user_var4 = UserVariable(name="rho_squared", value=user_var1 * user_var1)
1817+
user_var5 = UserVariable(name="momentum_like", value=user_var4 * user_var3)
1818+
expr5 = Expression(expression="momentum_like")
1819+
solver_vars = expr5.solver_variable_names(recursive=True)
1820+
assert set(solver_vars) == {"solution.density", "solution.velocity"}
1821+
1822+
# Test 6: Mixed direct and indirect solver variables with proper dimensions
1823+
expr6 = Expression(expression="momentum_like + solution.pressure * solution.density")
1824+
solver_vars = expr6.solver_variable_names(recursive=True)
1825+
assert set(solver_vars) == {"solution.density", "solution.velocity", "solution.pressure"}
1826+
1827+
# Test 7: User variable with dimensionless value
1828+
user_var6 = UserVariable(name="mach_number", value=0.3)
1829+
expr7 = Expression(expression="mach_number * solution.velocity[0]")
1830+
solver_vars = expr7.solver_variable_names(recursive=True)
1831+
assert solver_vars == ["solution.velocity"]
1832+
1833+
# Test 8: Filter by variable type - Volume variables only
1834+
expr8 = Expression(expression="solution.density + solution.velocity[0] + control.MachRef")
1835+
volume_vars = expr8.solver_variable_names(variable_type="Volume", recursive=True)
1836+
assert set(volume_vars) == {"solution.density", "solution.velocity"}
1837+
1838+
# Test 9: Filter by variable type - Scalar variables only
1839+
scalar_vars = expr8.solver_variable_names(variable_type="Scalar", recursive=True)
1840+
assert scalar_vars == ["control.MachRef"]
1841+
1842+
# Test 10: Complex flow physics expression with proper dimensions
1843+
user_var7 = UserVariable(
1844+
name="dyn_press", value=0.5 * solution.density * solution.velocity[0] * solution.velocity[0]
1845+
)
1846+
user_var8 = UserVariable(name="tot_press", value=solution.pressure + user_var7)
1847+
expr9 = Expression(expression="tot_press")
1848+
solver_vars = expr9.solver_variable_names(recursive=True)
1849+
assert set(solver_vars) == {"solution.density", "solution.velocity", "solution.pressure"}
1850+
1851+
# Test 11: Temperature-based expressions (for compressible flow)
1852+
user_var9 = UserVariable(
1853+
name="temp_ratio", value=solution.temperature / 300.0
1854+
) # Reference temperature
1855+
user_var10 = UserVariable(name="scaled_density", value=solution.density * user_var9)
1856+
expr10 = Expression(expression="scaled_density")
1857+
solver_vars = expr10.solver_variable_names(recursive=True)
1858+
assert set(solver_vars) == {"solution.temperature", "solution.density"}
1859+
1860+
# Test 12: Vector operations with proper indexing
1861+
user_var11 = UserVariable(
1862+
name="vel_mag_sq",
1863+
value=solution.velocity[0] * solution.velocity[0]
1864+
+ solution.velocity[1] * solution.velocity[1]
1865+
+ solution.velocity[2] * solution.velocity[2],
1866+
)
1867+
expr11 = Expression(expression="vel_mag_sq")
1868+
solver_vars = expr11.solver_variable_names(recursive=True)
1869+
assert solver_vars == ["solution.velocity"]
1870+
1871+
# Test 13: Deep nesting with multiple physics variables
1872+
user_var12 = UserVariable(name="ke_calc", value=0.5 * solution.density * user_var11)
1873+
user_var13 = UserVariable(name="te_calc", value=user_var12 + solution.pressure / (1.4 - 1.0))
1874+
user_var14 = UserVariable(name="epv_calc", value=user_var13 / solution.temperature)
1875+
expr12 = Expression(expression="epv_calc")
1876+
solver_vars = expr12.solver_variable_names(recursive=True)
1877+
expected_vars = {
1878+
"solution.density",
1879+
"solution.velocity",
1880+
"solution.pressure",
1881+
"solution.temperature",
1882+
}
1883+
assert set(solver_vars) == expected_vars
1884+
1885+
# Test 14: Mathematical functions with solver variables (using dimensionless ratios)
1886+
user_var15 = UserVariable(
1887+
name="rho_ratio", value=solution.density / solution.density
1888+
) # Dimensionless ratio
1889+
user_var16 = UserVariable(
1890+
name="complex_func", value=Expression(expression="math.exp(rho_ratio)")
1891+
)
1892+
expr13 = Expression(expression="complex_func")
1893+
solver_vars = expr13.solver_variable_names(recursive=True)
1894+
assert solver_vars == ["solution.density"]
1895+
1896+
# Test 15: Control variables in time-dependent expressions
1897+
user_var17 = UserVariable(
1898+
name="time_scaled_vel", value=solution.velocity[0] * control.timeStepSize
1899+
)
1900+
expr14 = Expression(expression="time_scaled_vel") # Just use the time-scaled velocity
1901+
solver_vars = expr14.solver_variable_names(recursive=True)
1902+
assert set(solver_vars) == {"solution.velocity", "control.timeStepSize"}
1903+
1904+
# Test 16: Edge case - circular reference prevention
1905+
user_var18 = UserVariable(name="base_var", value=solution.pressure)
1906+
user_var19 = UserVariable(name="derived_var", value=user_var18 * 2.0)
1907+
user_var20 = UserVariable(name="twice_derived", value=user_var19 + user_var18)
1908+
expr15 = Expression(expression="twice_derived")
1909+
solver_vars = expr15.solver_variable_names(recursive=True)
1910+
assert solver_vars == ["solution.pressure"]
1911+
1912+
# Test 17: Mixed dimensioned and dimensionless variables
1913+
user_var21 = UserVariable(name="re_num", value=1e6) # Dimensionless
1914+
user_var22 = UserVariable(
1915+
name="char_vel", value=user_var21 * solution.mut / (solution.density * 1.0)
1916+
) # Length = 1.0 m
1917+
expr16 = Expression(expression="char_vel")
1918+
solver_vars = expr16.solver_variable_names(recursive=True)
1919+
assert set(solver_vars) == {"solution.mut", "solution.density"}
1920+
1921+
# Test 18: Surface variables if available
1922+
try:
1923+
# Only test if surface variables exist
1924+
expr17 = Expression(expression="solution.Cp + solution.density")
1925+
solver_vars = expr17.solver_variable_names(variable_type="Surface", recursive=True)
1926+
# Check if surface variables are found
1927+
surface_vars = [var for var in solver_vars if "Cp" in var]
1928+
if surface_vars:
1929+
assert "solution.Cp" in solver_vars
1930+
except (AttributeError, ValueError):
1931+
# Skip if surface variables not available in test environment
1932+
pass
1933+
1934+
# Test 19: All variable types combined
1935+
expr18 = Expression(expression="solution.density + solution.velocity[0] + control.MachRef")
1936+
all_vars = expr18.solver_variable_names(variable_type="All", recursive=True)
1937+
assert set(all_vars) == {"solution.density", "solution.velocity", "control.MachRef"}
1938+
1939+
# Test 20: Simple expression with just a constant
1940+
expr19 = Expression(expression="42.0")
1941+
solver_vars = expr19.solver_variable_names(recursive=True)
1942+
assert solver_vars == []

0 commit comments

Comments
 (0)