@@ -46,7 +46,6 @@ def _load_pretrain_from_path(
46
46
path : str ,
47
47
model : nn .Layer ,
48
48
equation : Optional [Dict [str , equation .PDE ]] = None ,
49
- loss_aggregator : Optional [mtl .LossAggregator ] = None ,
50
49
):
51
50
"""Load pretrained model from given path.
52
51
@@ -81,26 +80,11 @@ def _load_pretrain_from_path(
81
80
f"Finish loading pretrained equation parameters from: { path } .pdeqn"
82
81
)
83
82
84
- if loss_aggregator is not None :
85
- if not os .path .exists (f"{ path } .pdagg" ):
86
- if loss_aggregator .should_persist :
87
- logger .warning (
88
- f"Given loss_aggregator({ type (loss_aggregator )} ) has persistable"
89
- f"parameters or buffers, but { path } .pdagg not found."
90
- )
91
- else :
92
- aggregator_dict = paddle .load (f"{ path } .pdagg" )
93
- loss_aggregator .set_state_dict (aggregator_dict )
94
- logger .message (
95
- f"Finish loading pretrained equation parameters from: { path } .pdagg"
96
- )
97
-
98
83
99
84
def load_pretrain (
100
85
model : nn .Layer ,
101
86
path : str ,
102
87
equation : Optional [Dict [str , equation .PDE ]] = None ,
103
- loss_aggregator : Optional [mtl .LossAggregator ] = None ,
104
88
):
105
89
"""
106
90
Load pretrained model from given path or url.
@@ -142,7 +126,7 @@ def is_url_accessible(url: str):
142
126
# remove ".pdparams" in suffix of path for convenient
143
127
if path .endswith (".pdparams" ):
144
128
path = path [:- 9 ]
145
- _load_pretrain_from_path (path , model , equation , loss_aggregator )
129
+ _load_pretrain_from_path (path , model , equation )
146
130
147
131
148
132
def load_checkpoint (
0 commit comments