Skip to content

Commit efa43c0

Browse files
[Fea&Fix] Support register dataset and fix several bugs (#1207)
* support register custom dataset to ppsci.data.dataset * fix typing for cfg * fix typing for cfg * fix several kernel not supported in xpu * fix drivaernet datasets
1 parent a8fd230 commit efa43c0

File tree

21 files changed

+180
-37
lines changed

21 files changed

+180
-37
lines changed

competition/IJCAI_2024_CAR

Submodule IJCAI_2024_CAR updated from 4e12f69 to 24adee6

ppsci/arch/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import copy
18+
from typing import TYPE_CHECKING
1819

1920
from ppsci.arch.afno import AFNONet # isort:skip
2021
from ppsci.arch.afno import PrecipNet # isort:skip
@@ -64,6 +65,10 @@
6465
from ppsci.arch.ifm_mlp import IFMMLP # isort:skip
6566
from ppsci.arch.stafnet import STAFNet # isort:skip
6667

68+
if TYPE_CHECKING:
69+
from omegaconf import DictConfig
70+
71+
6772
__all__ = [
6873
"MoFlowNet",
6974
"MoFlowProp",
@@ -118,7 +123,7 @@
118123
]
119124

120125

121-
def build_model(cfg):
126+
def build_model(cfg: DictConfig):
122127
"""Build model
123128
124129
Args:

ppsci/arch/afno.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,19 @@ def __init__(
203203
default_initializer=nn.initializer.Normal(std=self.scale),
204204
)
205205

206+
if paddle.device.is_compiled_with_xpu():
207+
# softshrink is not supported in xpu, so we use a custom implementation to avoid fallback to cpu
208+
def softshrink(x: paddle.Tensor, threshold: float = 0.5) -> paddle.Tensor:
209+
return paddle.where(
210+
x > threshold,
211+
x - threshold,
212+
paddle.where(x < -threshold, x + threshold, paddle.zeros_like(x)),
213+
)
214+
215+
self.softshrink = softshrink
216+
else:
217+
self.softshrink = F.softshrink
218+
206219
def forward(self, x):
207220
bias = x
208221

@@ -285,7 +298,7 @@ def forward(self, x):
285298
)
286299

287300
x = paddle.stack([o2_real, o2_imag], axis=-1)
288-
x = F.softshrink(x, threshold=self.sparsity_threshold)
301+
x = self.softshrink(x, threshold=self.sparsity_threshold)
289302
x = paddle.as_complex(x)
290303
x = x.reshape((B, H, W // 2 + 1, C))
291304
x = paddle.fft.irfft2(x, s=(H, W), axes=(1, 2), norm="ortho")

ppsci/constraint/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def build_constraint(cfg, equation_dict, geom_dict):
4242
"""Build constraint(s).
4343
4444
Args:
45-
cfg (List[DictConfig]): Constraint config list.
45+
cfg (DictConfig): Constraint config list.
4646
equation_dict (Dct[str, Equation]): Equation(s) in dict.
4747
geom_dict (Dct[str, Geometry]): Geometry(ies) in dict.
4848

ppsci/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ppsci.data import dataloader
2727
from ppsci.data import dataset
2828
from ppsci.data import process
29+
from ppsci.data.dataset import register_to_dataset
2930
from ppsci.data.process import batch_transform
3031
from ppsci.data.process import transform
3132
from ppsci.utils import logger
@@ -37,6 +38,7 @@
3738
"build_dataloader",
3839
"transform",
3940
"batch_transform",
41+
"register_to_dataset",
4042
]
4143

4244

ppsci/data/dataset/__init__.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import copy
18+
import sys
1619
from typing import TYPE_CHECKING
1720

21+
from paddle import io
22+
1823
from ppsci.data.dataset.airfoil_dataset import MeshAirfoilDataset
1924
from ppsci.data.dataset.array_dataset import ChipHeatDataset
2025
from ppsci.data.dataset.array_dataset import ContinuousNamedArrayDataset
@@ -56,7 +61,7 @@
5661
from ppsci.utils import logger
5762

5863
if TYPE_CHECKING:
59-
from paddle import io
64+
from omegaconf import DictConfig
6065

6166
__all__ = [
6267
"IterableNamedArrayDataset",
@@ -97,14 +102,15 @@
97102
"IFMMoeDataset",
98103
"STAFNetDataset",
99104
"TMTDataset",
105+
"register_to_dataset",
100106
]
101107

102108

103-
def build_dataset(cfg) -> "io.Dataset":
109+
def build_dataset(cfg: DictConfig) -> "io.Dataset":
104110
"""Build dataset
105111
106112
Args:
107-
cfg (List[DictConfig]): Dataset config list.
113+
cfg (DictConfig): Dataset config list.
108114
109115
Returns:
110116
Dict[str, io.Dataset]: dataset.
@@ -115,8 +121,37 @@ def build_dataset(cfg) -> "io.Dataset":
115121
if "transforms" in cfg:
116122
cfg["transforms"] = transform.build_transforms(cfg.pop("transforms"))
117123

118-
dataset = eval(dataset_cls)(**cfg)
124+
try:
125+
dataset = eval(dataset_cls)(**cfg)
126+
except NameError:
127+
import textwrap
128+
129+
logger.error(
130+
f"name {dataset_cls} is not defined, maybe you should register your dataset class first as below:\n"
131+
+ textwrap.indent(
132+
"\n"
133+
"import paddle\n"
134+
"from ppsci.data import register_to_dataset\n"
135+
"\n"
136+
"@register_to_dataset\n"
137+
"class MyDataset(paddle.io.Dataset):\n"
138+
" pass\n"
139+
"\n",
140+
prefix=" " * 4,
141+
)
142+
)
143+
raise
119144

120145
logger.debug(str(dataset))
121146

122147
return dataset
148+
149+
150+
def register_to_dataset(cls: type):
151+
from ppsci.utils.registry import register_cls_to_module
152+
153+
if not issubclass(cls, io.Dataset):
154+
logger.warning(
155+
f"The registered class '{cls.__name__}' should be inherited from `paddle.io.Dataset`"
156+
)
157+
register_cls_to_module(sys.modules[__name__], cls)

ppsci/data/dataset/drivaernet_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def _load_point_cloud(self, design_id: str) -> Optional[paddle.Tensor]:
245245
load_path = os.path.join(self.root_dir, f"{design_id}.paddle_tensor")
246246
if os.path.exists(load_path) and os.path.getsize(load_path) > 0:
247247
try:
248-
vertices = paddle.load(path=str(load_path))
248+
vertices: paddle.Tensor = paddle.load(path=str(load_path))
249249
num_vertices = vertices.shape[0]
250250

251251
if num_vertices > self.num_points:
@@ -255,6 +255,8 @@ def _load_point_cloud(self, design_id: str) -> Optional[paddle.Tensor]:
255255
vertices = vertices.numpy()[indices]
256256
vertices = paddle.to_tensor(vertices)
257257

258+
vertices = self._sample_or_pad_vertices(vertices, self.num_points)
259+
258260
return vertices
259261
except (EOFError, RuntimeError, ValueError) as e:
260262
raise Exception(

ppsci/data/dataset/drivaernetplusplus_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def _load_point_cloud(self, design_id: str):
246246
load_path = os.path.join(self.root_dir, f"{design_id}.paddle_tensor")
247247
if os.path.exists(load_path) and os.path.getsize(load_path) > 0:
248248
try:
249-
vertices = paddle.load(path=str(load_path))
249+
vertices: paddle.Tensor = paddle.load(path=str(load_path))
250250
except (EOFError, RuntimeError, ValueError) as e:
251251
raise Exception(
252252
f"Error loading point cloud from {load_path}: {e}"
@@ -256,6 +256,9 @@ def _load_point_cloud(self, design_id: str):
256256
if num_vertices > self.num_points:
257257
indices = np.random.choice(num_vertices, self.num_points, replace=False)
258258
vertices = vertices.numpy()[indices]
259+
vertices = paddle.to_tensor(vertices)
260+
261+
vertices = self._sample_or_pad_vertices(vertices, self.num_points)
259262

260263
return vertices
261264

ppsci/data/process/batch_transform/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import copy
1618
import numbers
1719
from collections.abc import Mapping
1820
from collections.abc import Sequence
21+
from typing import TYPE_CHECKING
1922
from typing import Any
2023
from typing import Callable
2124
from typing import List
@@ -27,6 +30,10 @@
2730
from ppsci.data.process import transform
2831
from ppsci.data.process.batch_transform.preprocess import FunctionalBatchTransform
2932

33+
if TYPE_CHECKING:
34+
from omegaconf import DictConfig
35+
36+
3037
try:
3138
import pgl
3239
except ModuleNotFoundError:
@@ -104,7 +111,7 @@ def default_collate_fn(batch: List[Any]) -> Any:
104111
)
105112

106113

107-
def build_transforms(cfg):
114+
def build_transforms(cfg: DictConfig):
108115
if not cfg:
109116
return transform.Compose([])
110117
cfg = copy.deepcopy(cfg)

ppsci/data/process/transform/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import copy
1618
import traceback
19+
from typing import TYPE_CHECKING
1720
from typing import Any
1821
from typing import Tuple
1922

2023
from paddle import vision
2124

25+
if TYPE_CHECKING:
26+
from omegaconf import DictConfig
27+
2228
from ppsci.data.process.transform.preprocess import CropData
2329
from ppsci.data.process.transform.preprocess import FunctionalTransform
2430
from ppsci.data.process.transform.preprocess import Log1p
@@ -57,7 +63,7 @@ def __call__(self, *data: Tuple[Any, ...]):
5763
return data
5864

5965

60-
def build_transforms(cfg):
66+
def build_transforms(cfg: DictConfig):
6167
if not cfg:
6268
return Compose([])
6369
cfg = copy.deepcopy(cfg)

0 commit comments

Comments
 (0)