Skip to content

Commit b02a2b1

Browse files
[Fix] Fix ldc_2d infer key order (#1087)
* update multi devices document * fix index * update doc * fix ldc 2d inference key order
1 parent 5192d93 commit b02a2b1

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

examples/ldc/ldc_2d_Re1000_plain.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,9 @@ def inference(cfg: DictConfig):
278278
# mapping data to cfg.INFER.output_keys
279279
output_dict = {
280280
store_key: output_dict[infer_key]
281-
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
281+
for store_key, infer_key in zip(
282+
sorted(cfg.MODEL.output_keys), output_dict.keys()
283+
)
282284
}
283285
U_pred = np.sqrt(output_dict["u"] ** 2 + output_dict["v"] ** 2).reshape(
284286
[len(x_star), len(y_star)]

examples/ldc/ldc_2d_Re3200_piratenet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,9 @@ def inference(cfg: DictConfig):
293293
# mapping data to cfg.INFER.output_keys
294294
output_dict = {
295295
store_key: output_dict[infer_key]
296-
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
296+
for store_key, infer_key in zip(
297+
sorted(cfg.MODEL.output_keys), output_dict.keys()
298+
)
297299
}
298300
U_pred = np.sqrt(output_dict["u"] ** 2 + output_dict["v"] ** 2).reshape(
299301
[len(x_star), len(y_star)]

examples/ldc/ldc_2d_Re3200_sota.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,9 @@ def inference(cfg: DictConfig):
289289
# mapping data to cfg.INFER.output_keys
290290
output_dict = {
291291
store_key: output_dict[infer_key]
292-
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
292+
for store_key, infer_key in zip(
293+
sorted(cfg.MODEL.output_keys), output_dict.keys()
294+
)
293295
}
294296
U_pred = np.sqrt(output_dict["u"] ** 2 + output_dict["v"] ** 2).reshape(
295297
[len(x_star), len(y_star)]

0 commit comments

Comments
 (0)