Skip to content

Commit a6613f4

Browse files
added check for 4d images in synthetic b0 generation
1 parent 9b36d4d commit a6613f4

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

micaflow/scripts/synth_b0.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,26 @@ def main():
656656
os.makedirs(temp_dir, exist_ok=True)
657657
print(f"\n{CYAN}Temporary directory:{RESET} {temp_dir}")
658658

659+
# --- FIX START: Handle 4D input dimensions for SynB0-DISCO compatibility ---
660+
# Ensure inputs are strictly 3D. If 4D with N=1 (e.g. 240x240x220x1),
661+
# squeeze them to prevent NumPy "inhomogeneous shape" errors during inference.
662+
663+
t1_img_check = nib.load(t1_path)
664+
if len(t1_img_check.shape) > 3 and t1_img_check.shape[3] == 1:
665+
print(f"{YELLOW}Warning: T1w input is 4D {t1_img_check.shape}. Squeezing to 3D for inference compatibility...{RESET}")
666+
t1_3d_path = os.path.join(temp_dir, "t1_3d_input.nii.gz")
667+
# Squeeze data and save to temp file
668+
nib.save(nib.Nifti1Image(t1_img_check.get_fdata().squeeze(), t1_img_check.affine), t1_3d_path)
669+
t1_path = t1_3d_path
670+
671+
b0_img_check = nib.load(b0_path)
672+
if len(b0_img_check.shape) > 3 and b0_img_check.shape[3] == 1:
673+
print(f"{YELLOW}Warning: B0 input is 4D {b0_img_check.shape}. Squeezing to 3D for inference compatibility...{RESET}")
674+
b0_3d_path = os.path.join(temp_dir, "b0_3d_input.nii.gz")
675+
nib.save(nib.Nifti1Image(b0_img_check.get_fdata().squeeze(), b0_img_check.affine), b0_3d_path)
676+
b0_path = b0_3d_path
677+
# --- FIX END ---
678+
659679
# Find models
660680
models_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'models')
661681
print(f"{CYAN}Models directory:{RESET} {models_dir}")

0 commit comments

Comments
 (0)