Skip to content

Commit d833a40

Browse files
[Fix] Fix batch indexing failed in phylstm2 (#941)
* fix batch indexing failed in list: * fix chapter number of adv_cvit.md(test=document_fix)
1 parent 3a1d360 commit d833a40

File tree

4 files changed

+31
-35
lines changed

4 files changed

+31
-35
lines changed

docs/zh/examples/adv_cvit.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ examples/adv/adv_cvit.py:117:125
145145
--8<--
146146
```
147147

148-
### 3.7 模型训练、评估
148+
### 3.6 模型训练、评估
149149

150150
完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练、评估。
151151

docs/zh/examples/phylstm.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,19 @@ examples/phylstm/phylstm2.py:37:100
106106

107107
设置训练数据集和损失计算函数,返回字段,代码如下所示:
108108

109-
``` py linenums="119"
109+
``` py linenums="120"
110110
--8<--
111-
examples/phylstm/phylstm2.py:119:145
111+
examples/phylstm/phylstm2.py:120:146
112112
--8<--
113113
```
114114

115115
### 3.4 评估器构建
116116

117117
设置评估数据集和损失计算函数,返回字段,代码如下所示:
118118

119-
``` py linenums="147"
119+
``` py linenums="148"
120120
--8<--
121-
examples/phylstm/phylstm2.py:147:174
121+
examples/phylstm/phylstm2.py:148:170
122122
--8<--
123123
```
124124

@@ -136,27 +136,27 @@ examples/phylstm/conf/phylstm2.yaml:39:39
136136

137137
训练过程会调用优化器来更新模型参数,此处选择 `Adam` 优化器并设定 `learning_rate` 为 1e-3。
138138

139-
``` py linenums="177"
139+
``` py linenums="172"
140140
--8<--
141-
examples/phylstm/phylstm2.py:177:177
141+
examples/phylstm/phylstm2.py:172:173
142142
--8<--
143143
```
144144

145145
### 3.7 模型训练与评估
146146

147147
完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`
148148

149-
``` py linenums="178"
149+
``` py linenums="174"
150150
--8<--
151-
examples/phylstm/phylstm2.py:178:192
151+
examples/phylstm/phylstm2.py:174:180
152152
--8<--
153153
```
154154

155155
最后启动训练、评估即可:
156156

157-
``` py linenums="194"
157+
``` py linenums="182"
158158
--8<--
159-
examples/phylstm/phylstm2.py:194:197
159+
examples/phylstm/phylstm2.py:182:185
160160
--8<--
161161
```
162162

examples/phylstm/functions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,19 @@ def get(self, epochs=1):
181181
np.asarray(0.0, dtype=paddle.get_default_dtype())
182182
)
183183

184+
def to_numpy_dict(dct):
185+
return {k: np.asarray(v, dtype="float32") for k, v in dct.items()}
186+
187+
input_dict_train = to_numpy_dict(input_dict_train)
188+
for k, v in input_dict_train.items():
189+
print(f"input_dict_train {k} {type(v)}")
190+
label_dict_train = to_numpy_dict(label_dict_train)
191+
for k, v in label_dict_train.items():
192+
print(f"label_dict_train {k} {type(v)}")
193+
input_dict_val = to_numpy_dict(input_dict_val)
194+
for k, v in input_dict_val.items():
195+
print(f"input_dict_val {k} {type(v)}")
196+
label_dict_val = to_numpy_dict(label_dict_val)
197+
for k, v in label_dict_val.items():
198+
print(f"label_dict_val {k} {type(v)}")
184199
return input_dict_train, label_dict_train, input_dict_val, label_dict_val

examples/phylstm/phylstm2.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def train(cfg: DictConfig):
109109
model.register_output_transform(functions.transform_out)
110110

111111
dataset_obj = functions.Dataset(eta, eta_t, g, ag, ag_c, lift, phi_t)
112+
112113
(
113114
input_dict_train,
114115
label_dict_train,
@@ -151,11 +152,6 @@ def train(cfg: DictConfig):
151152
"input": input_dict_val,
152153
"label": label_dict_val,
153154
},
154-
"sampler": {
155-
"name": "BatchSampler",
156-
"drop_last": False,
157-
"shuffle": False,
158-
},
159155
"batch_size": 1,
160156
"num_workers": 0,
161157
},
@@ -178,17 +174,9 @@ def train(cfg: DictConfig):
178174
solver = ppsci.solver.Solver(
179175
model,
180176
constraint_pde,
181-
cfg.output_dir,
182-
optimizer,
183-
None,
184-
cfg.TRAIN.epochs,
185-
cfg.TRAIN.iters_per_epoch,
186-
save_freq=cfg.TRAIN.save_freq,
187-
log_freq=cfg.log_freq,
188-
seed=cfg.seed,
177+
optimizer=optimizer,
189178
validator=validator_pde,
190-
checkpoint_path=cfg.TRAIN.checkpoint_path,
191-
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
179+
cfg=cfg,
192180
)
193181

194182
# train model
@@ -278,6 +266,7 @@ def evaluate(cfg: DictConfig):
278266
model.register_output_transform(functions.transform_out)
279267

280268
dataset_obj = functions.Dataset(eta, eta_t, g, ag, ag_c, lift, phi_t)
269+
281270
(
282271
_,
283272
_,
@@ -292,11 +281,6 @@ def evaluate(cfg: DictConfig):
292281
"input": input_dict_val,
293282
"label": label_dict_val,
294283
},
295-
"sampler": {
296-
"name": "BatchSampler",
297-
"drop_last": False,
298-
"shuffle": False,
299-
},
300284
"batch_size": 1,
301285
"num_workers": 0,
302286
},
@@ -317,11 +301,8 @@ def evaluate(cfg: DictConfig):
317301
# initialize solver
318302
solver = ppsci.solver.Solver(
319303
model,
320-
output_dir=cfg.output_dir,
321-
seed=cfg.seed,
322304
validator=validator_pde,
323-
pretrained_model_path=cfg.EVAL.pretrained_model_path,
324-
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
305+
cfg=cfg,
325306
)
326307
# evaluate
327308
solver.eval()

0 commit comments

Comments
 (0)