Skip to content

Commit eb691ff

Browse files
[Doc] Add HKUST badge and remove some codes (#950)
* add HKUST icon in cooperation institution * remove unnecessary loss_aggregator in load_pretrain function * resize image * add requests into requirements
1 parent 779048d commit eb691ff

File tree

3 files changed

+2
-17
lines changed

3 files changed

+2
-17
lines changed

docs/images/overview/cooperation.png

351 KB
Loading

ppsci/utils/save_load.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def _load_pretrain_from_path(
4646
path: str,
4747
model: nn.Layer,
4848
equation: Optional[Dict[str, equation.PDE]] = None,
49-
loss_aggregator: Optional[mtl.LossAggregator] = None,
5049
):
5150
"""Load pretrained model from given path.
5251
@@ -81,26 +80,11 @@ def _load_pretrain_from_path(
8180
f"Finish loading pretrained equation parameters from: {path}.pdeqn"
8281
)
8382

84-
if loss_aggregator is not None:
85-
if not os.path.exists(f"{path}.pdagg"):
86-
if loss_aggregator.should_persist:
87-
logger.warning(
88-
f"Given loss_aggregator({type(loss_aggregator)}) has persistable"
89-
f"parameters or buffers, but {path}.pdagg not found."
90-
)
91-
else:
92-
aggregator_dict = paddle.load(f"{path}.pdagg")
93-
loss_aggregator.set_state_dict(aggregator_dict)
94-
logger.message(
95-
f"Finish loading pretrained equation parameters from: {path}.pdagg"
96-
)
97-
9883

9984
def load_pretrain(
10085
model: nn.Layer,
10186
path: str,
10287
equation: Optional[Dict[str, equation.PDE]] = None,
103-
loss_aggregator: Optional[mtl.LossAggregator] = None,
10488
):
10589
"""
10690
Load pretrained model from given path or url.
@@ -142,7 +126,7 @@ def is_url_accessible(url: str):
142126
# remove ".pdparams" in suffix of path for convenient
143127
if path.endswith(".pdparams"):
144128
path = path[:-9]
145-
_load_pretrain_from_path(path, model, equation, loss_aggregator)
129+
_load_pretrain_from_path(path, model, equation)
146130

147131

148132
def load_checkpoint(

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ numpy>=1.20.0,<=1.23.1
99
pydantic>=2.5.0
1010
pyevtk
1111
pyyaml
12+
requests
1213
scikit-learn<1.5.0
1314
scikit-optimize
1415
scipy

0 commit comments

Comments
 (0)