Skip to content

Commit 6c79832

Browse files
[Fix] Fix TimeXGeometry wrong data genearation when given time_step (#1178)
* fix timeXgeometry wrong time stamps genearation when given time_step * add test for geom which do not has 'sdf_func'
1 parent 5ea90ae commit 6c79832

File tree

4 files changed

+120
-72
lines changed

4 files changed

+120
-72
lines changed

ppsci/geometry/geometry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ def sample_interior(
134134
) -> Dict[str, np.ndarray]:
135135
"""Sample random points in the geometry and return those meet criteria.
136136
137+
NOTE: sdf values returned by this function are negated because the weight in
138+
loss function should be positive.
139+
137140
Args:
138141
n (int): Number of points.
139142
random (Literal["pseudo", "Halton", "LHS"]): Random method. Defaults to "pseudo".
@@ -211,6 +214,7 @@ def sample_interior(
211214

212215
# if sdf_func added, return x_dict and sdf_dict, else, only return the x_dict
213216
if hasattr(self, "sdf_func"):
217+
# NOTE: add negative to the sdf values because weight should be positive.
214218
sdf = -self.sdf_func(x)
215219
sdf_dict = misc.convert_to_dict(sdf, ("sdf",))
216220
sdf_derives_dict = {}

ppsci/geometry/mesh.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,12 @@ def sample_interior(
582582
evenly: bool = False,
583583
compute_sdf_derivatives: bool = False,
584584
):
585-
"""Sample random points in the geometry and return those meet criteria."""
585+
"""
586+
Sample random points in the geometry and return those meet criteria.
587+
588+
NOTE: sdf values returned by this function are negated because the weight in
589+
loss function should be positive.
590+
"""
586591
if evenly:
587592
# TODO(sensen): Implement uniform sample for mesh interior.
588593
raise NotImplementedError(
@@ -1112,7 +1117,12 @@ def sample_interior(
11121117
evenly: bool = False,
11131118
compute_sdf_derivatives: bool = False,
11141119
):
1115-
"""Sample random points in the geometry and return those meet criteria."""
1120+
"""
1121+
Sample random points in the geometry and return those meet criteria.
1122+
1123+
NOTE: sdf values returned by this function are negated because the weight in
1124+
loss function should be positive.
1125+
"""
11161126
if evenly:
11171127
# TODO(sensen): Implement uniform sample for mesh interior.
11181128
raise NotImplementedError(
@@ -1128,7 +1138,7 @@ def sample_interior(
11281138
if compute_sdf_derivatives:
11291139
sdf, sdf_derives = sdf
11301140

1131-
# NOTE: Negate sdf because weight should be positive.
1141+
# NOTE: add negative to the sdf values because weight should be positive.
11321142
sdf_dict = misc.convert_to_dict(-sdf, ("sdf",))
11331143

11341144
sdf_derives_dict = {}

ppsci/geometry/timedomain.py

Lines changed: 55 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import annotations
2020

2121
import itertools
22+
import types
2223
from typing import Callable
2324
from typing import Dict
2425
from typing import Optional
@@ -61,19 +62,23 @@ def __init__(
6162
super().__init__(t0, t1)
6263
self.t0 = t0
6364
self.t1 = t1
65+
assert (
66+
time_step is None or timestamps is None
67+
), "time_step and timestamps cannot be both set."
6468
self.time_step = time_step
65-
if timestamps is None:
66-
self.timestamps = None
67-
else:
69+
if timestamps is not None:
6870
self.timestamps = np.array(
6971
timestamps, dtype=paddle.get_default_dtype()
7072
).reshape([-1])
71-
if time_step is not None:
72-
if time_step <= 0:
73-
raise ValueError(f"time_step({time_step}) must be larger than 0.")
74-
self.num_timestamps = int(np.ceil((t1 - t0) / time_step)) + 1
75-
elif timestamps is not None:
76-
self.num_timestamps = len(timestamps)
73+
self.num_timestamps = len(self.timestamps)
74+
elif time_step is not None:
75+
# set timestamps manually with given time_step
76+
self.timestamps = np.arange(
77+
t0, t1, time_step, dtype=paddle.get_default_dtype()
78+
)
79+
self.num_timestamps = len(self.timestamps)
80+
else:
81+
self.timestamps = None
7782

7883
def on_initial(self, t: np.ndarray) -> np.ndarray:
7984
"""Check if a specific time is on the initial time point.
@@ -115,6 +120,35 @@ def __init__(self, timedomain: TimeDomain, geometry: geometry.Geometry):
115120
self.geometry = geometry
116121
self.ndim = geometry.ndim + timedomain.ndim
117122

123+
if hasattr(self.geometry, "sdf_func"):
124+
125+
def sdf_func(self, points: np.ndarray) -> np.ndarray:
126+
"""Compute signed distance field.
127+
128+
Args:
129+
points (np.ndarray): The temporal-spatial coordinate points used to calculate
130+
the SDF value, the shape is [N, 1+D], where 1 represents the temporal
131+
dimension and D represents the spatial dimensions.
132+
133+
Returns:
134+
np.ndarray: SDF values of input points without squared, the shape is [N, 1].
135+
136+
NOTE: This function usually returns ndarray with negative values, because
137+
according to the definition of SDF, the SDF value of the coordinate point inside
138+
the object(interior points) is negative, the outside is positive, and the edge
139+
is 0. Therefore, when used for weighting, a negative sign is often added before
140+
the result of this function.
141+
"""
142+
if points.shape[1] != self.ndim:
143+
raise ValueError(
144+
f"Shape of given points should be [*, {self.ndim}], but got {points.shape}"
145+
)
146+
spatial_points = points[:, 1:]
147+
sdf = self.geometry.sdf_func(spatial_points)
148+
return sdf
149+
150+
self.sdf_func = types.MethodType(sdf_func, self)
151+
118152
@property
119153
def dim_keys(self):
120154
return ("t",) + self.geometry.dim_keys
@@ -153,11 +187,10 @@ def uniform_points(self, n: int, boundary: bool = True) -> np.ndarray:
153187
>>> print(ts.shape)
154188
(1000, 3)
155189
"""
156-
if self.timedomain.time_step is not None:
157-
# exclude start time t0
158-
nt = int(np.ceil(self.timedomain.diam / self.timedomain.time_step))
159-
nx = int(np.ceil(n / nt))
160-
elif self.timedomain.timestamps is not None:
190+
if (
191+
self.timedomain.time_step is not None
192+
or self.timedomain.timestamps is not None
193+
):
161194
# exclude start time t0
162195
nt = self.timedomain.num_timestamps - 1
163196
nx = int(np.ceil(n / nt))
@@ -211,7 +244,8 @@ def random_points(
211244
criteria (Optional[Callable]): A method that filters on the generated random points. Defaults to None.
212245
213246
Returns:
214-
np.ndarray: A set of random spatial-temporal points.
247+
np.ndarray: A array of random spatial-temporal points with shape [N, 1+D], where 1 represents the
248+
temporal dimension and D represents the spatial dimensions.
215249
216250
Examples:
217251
>>> import ppsci
@@ -225,63 +259,14 @@ def random_points(
225259
if self.timedomain.time_step is None and self.timedomain.timestamps is None:
226260
raise ValueError("Either time_step or timestamps must be provided.")
227261
# time evenly and geometry random, if time_step if specified
228-
if self.timedomain.time_step is not None:
229-
nt = int(np.ceil(self.timedomain.diam / self.timedomain.time_step))
230-
t = np.linspace(
231-
self.timedomain.t1,
232-
self.timedomain.t0,
233-
num=nt,
234-
endpoint=False,
235-
dtype=paddle.get_default_dtype(),
236-
)[:, None][
237-
::-1
238-
] # [nt, 1]
262+
if (
263+
self.timedomain.time_step is not None
264+
or self.timedomain.timestamps is not None
265+
):
239266
# 1. sample nx points in static geometry with criteria
240-
nx = int(np.ceil(n / nt))
241-
_size, _ntry, _nsuc = 0, 0, 0
242-
x = np.empty(
243-
shape=(nx, self.geometry.ndim), dtype=paddle.get_default_dtype()
244-
)
245-
while _size < nx:
246-
_x = self.geometry.random_points(nx, random)
247-
if criteria is not None:
248-
# fix arg 't' to None in criteria there
249-
criteria_mask = criteria(
250-
None, *np.split(_x, self.geometry.ndim, axis=1)
251-
).flatten()
252-
_x = _x[criteria_mask]
253-
if len(_x) > nx - _size:
254-
_x = _x[: nx - _size]
255-
x[_size : _size + len(_x)] = _x
256-
257-
_size += len(_x)
258-
_ntry += 1
259-
if len(_x) > 0:
260-
_nsuc += 1
261-
262-
if _ntry >= 1000 and _nsuc == 0:
263-
raise ValueError(
264-
"Sample points failed, "
265-
"please check correctness of geometry and given criteria."
266-
)
267-
268-
# 2. repeat spatial points along time
269-
tx = []
270-
for ti in t:
271-
tx.append(
272-
np.hstack(
273-
(np.full([nx, 1], ti, dtype=paddle.get_default_dtype()), x)
274-
)
275-
)
276-
tx = np.vstack(tx)
277-
if len(tx) > n:
278-
tx = tx[:n]
279-
return tx
280-
elif self.timedomain.timestamps is not None:
281-
nt = self.timedomain.num_timestamps - 1
282267
t = self.timedomain.timestamps[1:]
268+
nt = self.timedomain.num_timestamps - 1
283269
nx = int(np.ceil(n / nt))
284-
285270
_size, _ntry, _nsuc = 0, 0, 0
286271
x = np.empty(
287272
shape=(nx, self.geometry.ndim), dtype=paddle.get_default_dtype()
@@ -309,6 +294,7 @@ def random_points(
309294
"please check correctness of geometry and given criteria."
310295
)
311296

297+
# 2. repeat spatial points along time
312298
tx = []
313299
for ti in t:
314300
tx.append(

test/geometry/test_timedomain_sdf.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytest
17+
18+
from ppsci import geometry
19+
20+
21+
def test_sdf_of_TimeXGeometry():
22+
timedomain = geometry.TimeDomain(0, 1, time_step=0.3)
23+
geom = geometry.Rectangle((0, 0), (1, 1))
24+
time_geom = geometry.TimeXGeometry(timedomain, geom)
25+
26+
interior_points = time_geom.sample_interior(
27+
timedomain.num_timestamps * 3000, compute_sdf_derivatives=True
28+
)
29+
30+
assert "sdf" in interior_points
31+
assert "sdf__x" in interior_points
32+
assert "sdf__y" in interior_points
33+
34+
interior_points = {"x": np.linspace(-1, 1, dtype="float32").reshape((-1, 1))}
35+
geom = geometry.PointCloud(interior_points, ("x",))
36+
time_geom = geometry.TimeXGeometry(timedomain, geom)
37+
38+
interior_points = time_geom.sample_interior(
39+
timedomain.num_timestamps * 10, compute_sdf_derivatives=True
40+
)
41+
42+
assert "sdf" not in interior_points
43+
assert "sdf__x" not in interior_points
44+
assert "sdf__y" not in interior_points
45+
46+
47+
if __name__ == "__main__":
48+
pytest.main()

0 commit comments

Comments
 (0)