@@ -427,11 +427,21 @@ def _split_condition(single_condition_dict, splits_dict):
427427 offset = 0
428428
429429 for stage , stage_len in splits_dict .items ():
430+ # filter keys to return dict, only batch items are listed here
431+ # Equations are NEVER dataloaded, None variables are not dataloaded
432+ filtered_keys = []
433+ for k in single_condition_dict .keys ():
434+ if k == "equation" or (
435+ k == "conditional_variables"
436+ and single_condition_dict [k ] is None
437+ ):
438+ continue
439+ else :
440+ filtered_keys .append (k )
441+ # return the correct splitting
430442 to_return_dict [stage ] = {
431- k : v [offset : offset + stage_len ]
432- for k , v in single_condition_dict .items ()
433- if k != "equation"
434- # Equations are NEVER dataloaded
443+ k : single_condition_dict [k ][offset : offset + stage_len ]
444+ for k in filtered_keys
435445 }
436446 if offset + stage_len >= len_condition :
437447 offset = len_condition - 1
@@ -456,6 +466,8 @@ def _apply_shuffle(condition_dict, len_data):
456466 for k , v in condition_dict .items ():
457467 if k == "equation" :
458468 continue
469+ if k == "conditional_variables" and condition_dict [k ] is None :
470+ continue
459471 if isinstance (v , list ):
460472 condition_dict [k ] = [v [i ] for i in idx ]
461473 elif isinstance (v , LabelTensor ):
0 commit comments