Skip to content

Commit 1d06df2

Browse files
committed
fixing keys in dataloader _split_condition
1 parent d4fa3ea commit 1d06df2

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

pina/data/data_module.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -427,11 +427,21 @@ def _split_condition(single_condition_dict, splits_dict):
427427
offset = 0
428428

429429
for stage, stage_len in splits_dict.items():
430+
# filter keys to return dict, only batch items are listed here
431+
# Equations are NEVER dataloaded, None variables are not dataloaded
432+
filtered_keys = []
433+
for k in single_condition_dict.keys():
434+
if k == "equation" or (
435+
k == "conditional_variables"
436+
and single_condition_dict[k] is None
437+
):
438+
continue
439+
else:
440+
filtered_keys.append(k)
441+
# return the correct splitting
430442
to_return_dict[stage] = {
431-
k: v[offset : offset + stage_len]
432-
for k, v in single_condition_dict.items()
433-
if k != "equation"
434-
# Equations are NEVER dataloaded
443+
k: single_condition_dict[k][offset : offset + stage_len]
444+
for k in filtered_keys
435445
}
436446
if offset + stage_len >= len_condition:
437447
offset = len_condition - 1
@@ -456,6 +466,8 @@ def _apply_shuffle(condition_dict, len_data):
456466
for k, v in condition_dict.items():
457467
if k == "equation":
458468
continue
469+
if k == "conditional_variables" and condition_dict[k] is None:
470+
continue
459471
if isinstance(v, list):
460472
condition_dict[k] = [v[i] for i in idx]
461473
elif isinstance(v, LabelTensor):

0 commit comments

Comments
 (0)