Skip to content

Commit 2d636cb

Browse files
[API] Support sdf weight for initial constraint (#1184)
* support sdf weight for initial constraint * add field check
1 parent 4a1f717 commit 2d636cb

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

ppsci/constraint/initial_constraint.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,16 @@ def __init__(
139139
if weight_dict is not None:
140140
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
141141
for key, value in weight_dict.items():
142-
if isinstance(value, (int, float)):
142+
if isinstance(value, str):
143+
if value == "sdf":
144+
if "sdf" not in input:
145+
raise ValueError(
146+
f"Missing 'sdf' field in input. Please check whether the geometry ({geom.geometry.__class__.__name__}) implements 'sdf_func'"
147+
)
148+
weight[key] = input["sdf"]
149+
else:
150+
raise NotImplementedError(f"string {value} is invalid yet.")
151+
elif isinstance(value, (int, float)):
143152
weight[key] = np.full_like(next(iter(label.values())), value)
144153
elif isinstance(value, sympy.Basic):
145154
func = sympy.lambdify(

ppsci/constraint/interior_constraint.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ def __init__(
138138
for key, value in weight_dict.items():
139139
if isinstance(value, str):
140140
if value == "sdf":
141+
if "sdf" not in input:
142+
raise ValueError(
143+
f"Missing 'sdf' field in input. Please check whether the geometry ({geom.__class__.__name__}) implements 'sdf_func'"
144+
)
141145
weight[key] = input["sdf"]
142146
else:
143147
raise NotImplementedError(f"string {value} is invalid yet.")
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from ppsci import constraint
18+
from ppsci import geometry
19+
from ppsci import loss
20+
21+
22+
def test_sdf_weight_of_initialconstraint():
23+
rect = geometry.TimeXGeometry(
24+
geometry.TimeDomain(0, 1),
25+
geometry.Rectangle((0, 0), (1, 1)),
26+
)
27+
ic = constraint.InitialConstraint(
28+
{"u": lambda out: out["u"]},
29+
{"u": 0},
30+
rect,
31+
{
32+
"dataset": "IterableNamedArrayDataset",
33+
"iters_per_epoch": 1,
34+
"batch_size": 16,
35+
},
36+
loss.MSELoss("mean"),
37+
weight_dict={"u": "sdf"},
38+
name="IC",
39+
) # doctest: +SKIP
40+
41+
input, _, _ = next(iter(ic.data_iter))
42+
assert "sdf" in input.keys()
43+
44+
45+
if __name__ == "__main__":
46+
pytest.main()

0 commit comments

Comments
 (0)