Fine-tuning encoder for downstream task #3

@SnehaChandna

I have a classification task on disease progression.

  • I have trained LongiSeg on my data.
  • I want to extract the trained encoder from the .pth checkpoint and fine-tune it on my downstream classification task.
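
For context, the downstream model I have in mind could look roughly like the sketch below. It is only an illustration under assumptions: `EncoderClassifier`, the stand-in encoder, and `num_classes=3` are hypothetical and not part of LongiSeg; the real encoder would come from the trained checkpoint.

```python
import torch
import torch.nn as nn


class EncoderClassifier(nn.Module):
    """Hypothetical wrapper: encoder -> global average pooling -> linear head."""

    def __init__(self, encoder: nn.Module, feature_channels: int, num_classes: int):
        super().__init__()
        self.encoder = encoder
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.head = nn.Linear(feature_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.encoder(x)
        # Some encoders return a list of per-stage feature maps; keep the deepest one.
        if isinstance(features, (list, tuple)):
            features = features[-1]
        pooled = self.pool(features).flatten(1)  # [batch, feature_channels]
        return self.head(pooled)


# Stand-in encoder (assumption): replace with the encoder taken from the checkpoint.
stand_in_encoder = nn.Sequential(
    nn.Conv3d(8, 16, kernel_size=3, stride=2, padding=1),
    nn.InstanceNorm3d(16, affine=True),
    nn.LeakyReLU(0.01),
)
model = EncoderClassifier(stand_in_encoder, feature_channels=16, num_classes=3)

# Freeze the encoder so that only the new head is trained at first.
for param in model.encoder.parameters():
    param.requires_grad = False

logits = model(torch.randn(2, 8, 16, 16, 16))
print(logits.shape)
```

Freezing the encoder first and unfreezing it later is one common fine-tuning schedule, not a requirement.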

Setup

I initialized the predictor with the pretrained model using the code below.

    # Build the predictor with the inference settings.
    predictor = LongiSegPredictor(
        tile_step_size=TILE_STEP_SIZE, use_gaussian=USE_GAUSSIAN, use_mirroring=not DISABLE_TTA,
        perform_everything_on_device=True, device=device, verbose=VERBOSE,
        verbose_preprocessing=VERBOSE, allow_tqdm=not DISABLE_PROGRESS_BAR
    )

    # Load the trained weights for the selected folds.
    predictor.initialize_from_trained_model_folder(
        model_training_output_dir=MODEL_TRAINING_OUTPUT_DIR,
        use_folds=USE_FOLDS, checkpoint_name=CHECKPOINT_NAME
    )

I then extracted the encoder like this:

    original_network = predictor.network
    encoder = original_network.encoder

Printing the encoder gives:

encoder PlainConvEncoder(
  (stages): Sequential(
    (0): Sequential(
      (0): StackedConvBlocks(
        (convs): Sequential(
          (0): ConvDropoutNormReLU(
            (conv): Conv3d(8, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (norm): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(8, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (1): ConvDropoutNormReLU(
            (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (norm): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
        )
      )
    )
    (1): Sequential(
      (0): StackedConvBlocks(
        (convs): Sequential(
          (0): ConvDropoutNormReLU(
            (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
            (norm): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
              (1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (1): ConvDropoutNormReLU(
            (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (norm): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
        )
      )
    )
    (2): Sequential(
      (0): StackedConvBlocks(
        (convs): Sequential(
          (0): ConvDropoutNormReLU(
            (conv): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
            (norm): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
              (1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (1): ConvDropoutNormReLU(
            (conv): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (norm): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
        )
      )
    )
    (3): Sequential(
      (0): StackedConvBlocks(
        (convs): Sequential(
          (0): ConvDropoutNormReLU(
            (conv): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
            (norm): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
              (1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (1): ConvDropoutNormReLU(
            (conv): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (norm): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
        )
      )
    )
    (4): Sequential(
      (0): StackedConvBlocks(
        (convs): Sequential(
          (0): ConvDropoutNormReLU(
            (conv): Conv3d(256, 320, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
            (norm): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(256, 320, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
              (1): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (1): ConvDropoutNormReLU(
            (conv): Conv3d(320, 320, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (norm): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(320, 320, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
        )
      )
    )
    (5): Sequential(
      (0): StackedConvBlocks(
        (convs): Sequential(
          (0): ConvDropoutNormReLU(
            (conv): Conv3d(320, 320, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1))
            (norm): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(320, 320, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1))
              (1): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (1): ConvDropoutNormReLU(
            (conv): Conv3d(320, 320, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (norm): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv3d(320, 320, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
        )
      )
    )
  )
)
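
To size a classification head, it helps to know the spatial shape at the deepest stage. The sketch below derives it from the strides in the printout above, assuming the full [144, 230, 401] volume is passed through the encoder in one go (the predictor may work on smaller patches, so this is only an illustration). `conv_out_size`, `STAGE_STRIDES` and `bottleneck_shape` are names introduced here.

```python
def conv_out_size(size: int, stride: int, kernel: int = 3, padding: int = 1) -> int:
    """Output length of one convolution along a single axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Stride of the first conv in each of the six stages, read from the printout.
# The second conv in every stage has stride 1 and keeps the size unchanged.
STAGE_STRIDES = [
    (1, 1, 1),
    (2, 2, 2),
    (2, 2, 2),
    (2, 2, 2),
    (2, 2, 2),
    (1, 2, 2),
]


def bottleneck_shape(spatial, strides=STAGE_STRIDES):
    """Spatial shape after applying every stage's stride in order."""
    shape = tuple(spatial)
    for stride in strides:
        shape = tuple(conv_out_size(n, s) for n, s in zip(shape, stride))
    return shape


print(bottleneck_shape((144, 230, 401)))  # (9, 8, 13)
```

With 320 channels in the last stage, the deepest feature map would then be [320, 9, 8, 13] per sample under this assumption.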

My question

My data has 4 modalities and 2 timepoints, which I load using:

    # Collect the input files and the output names for each case.
    list_of_lists, output_filenames_truncated, seg_from_prev_stage_files = \
        predictor._manage_input_and_output_lists(
            source_folder=source_folder, output_folder=temp_output_dir,
            patient_json=patient_json_data, folder_with_segs_from_prev_stage=None,
            overwrite=True, part_id=0, num_parts=1, save_probabilities=False
        )

    # Build the iterator that yields the preprocessed data.
    data_iterator = predictor._internal_get_data_iterator_from_lists_of_filenames(
        input_list_of_lists=list_of_lists, seg_from_prev_stage_files=seg_from_prev_stage_files,
        output_filenames_truncated=output_filenames_truncated, is_longitudinal=predictor.is_longitudinal,
        num_processes=num_processes_preprocessing
    )

The data_iterator outputs an array of shape [8, 144, 230, 401]. This is correct: it holds 4 modalities for each of the 2 timepoints, with data_c and data_p stacked using np.vstack.
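
Because np.vstack concatenates along the first axis, the two timepoints can be recovered by slicing the channel axis. The sketch assumes the order is data_c first and data_p second, which I have not verified against the LongiSeg preprocessing. `split_timepoints` is a hypothetical helper, and the small array shape is only for illustration.

```python
import numpy as np


def split_timepoints(stacked: np.ndarray, num_modalities: int = 4):
    """Split a [2 * num_modalities, D, H, W] array into its two timepoints."""
    if stacked.shape[0] != 2 * num_modalities:
        raise ValueError(
            f"expected {2 * num_modalities} channels, got {stacked.shape[0]}"
        )
    return stacked[:num_modalities], stacked[num_modalities:]


# Small dummy volumes standing in for the real [4, 144, 230, 401] arrays.
data_c = np.zeros((4, 6, 5, 7), dtype=np.float32)
data_p = np.ones((4, 6, 5, 7), dtype=np.float32)

stacked = np.vstack((data_c, data_p))  # shape (8, 6, 5, 7)
first, second = split_timepoints(stacked)
print(stacked.shape, first.shape, second.shape)
```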

The problem is that the encoder expects an 8-channel input.
My understanding was that the encoder should act separately on data_c [4, 144, 230, 401] and data_p [4, 144, 230, 401], and that both outputs are then sent to the difference weighting block, whose output can be passed on to a neural network for my downstream classification.
That is not possible with the current encoder, because it takes an 8-channel input.
However, the whole pretrained model works correctly when I use it for inference with the code in longiseg.inference.predict_from_raw_data_longi.py.
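
To make my expectation concrete, this is roughly the data flow I had in mind, written with a stand-in 4-channel encoder. The plain subtraction is only a placeholder for the difference weighting block, whose actual implementation I am not reproducing here. `SharedEncoderPair` and the tiny encoder are hypothetical.

```python
import torch
import torch.nn as nn


class SharedEncoderPair(nn.Module):
    """Apply one encoder (shared weights) to both timepoints and compare features."""

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, data_c: torch.Tensor, data_p: torch.Tensor) -> torch.Tensor:
        feat_c = self.encoder(data_c)
        feat_p = self.encoder(data_p)
        # Placeholder comparison; NOT the LongiSeg difference weighting block.
        return feat_c - feat_p


# Stand-in encoder with 4 input channels (one per modality).
tiny_encoder = nn.Sequential(
    nn.Conv3d(4, 16, kernel_size=3, stride=2, padding=1),
    nn.LeakyReLU(0.01),
)
pair = SharedEncoderPair(tiny_encoder)

data_c = torch.randn(1, 4, 16, 16, 16)
data_p = torch.randn(1, 4, 16, 16, 16)
diff = pair(data_c, data_p)
print(diff.shape)  # torch.Size([1, 16, 8, 8, 8])
```

An encoder whose first convolution takes 8 channels, like the one printed above, cannot be used this way without changing its first layer.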

I cannot understand where I am going wrong: is it in the preprocessing or in the model loading?
Can you help me with this?
