diff --git a/ideeplc/__main__.py b/ideeplc/__main__.py index 3618e26..ca8db40 100644 --- a/ideeplc/__main__.py +++ b/ideeplc/__main__.py @@ -61,6 +61,13 @@ def _argument_parser() -> argparse.ArgumentParser: action="store_true", help="Flag to enable fine-tuning of the model.", ) + parser.add_argument( + "-m", + "--model", + type=str, + required=False, + help="Path to the pretrained model.", + ) parser.add_argument( "-l", "--log_level", diff --git a/ideeplc/ideeplc_core.py b/ideeplc/ideeplc_core.py index 0eddcf5..958f2b2 100644 --- a/ideeplc/ideeplc_core.py +++ b/ideeplc/ideeplc_core.py @@ -55,7 +55,7 @@ def get_model_save_path(): # If normal Python environment model_path = files("ideeplc.models").joinpath(model_name) - return model_path, model_dir, model_path + return model_path, model_dir def main(args): @@ -81,7 +81,14 @@ def main(args): # Load pre-trained model LOGGER.info("Loading pre-trained model") - best_model_path, model_dir, pretrained_model = get_model_save_path() + pretrained_model, model_dir = get_model_save_path() + if args.model: + try: + logging.info(f"Using user-specified model path: {args.model}") + pretrained_model = Path(args.model) + except Exception as e: + LOGGER.error(f"Invalid model path provided: {e}") + raise e model.load_state_dict( torch.load(pretrained_model, map_location=device), strict=False ) diff --git a/ideeplc/models/pretrained_model.pth b/ideeplc/models/pretrained_model.pth index d0e40f3..fa22dc6 100644 Binary files a/ideeplc/models/pretrained_model.pth and b/ideeplc/models/pretrained_model.pth differ diff --git a/tests/test_model_save_path.py b/tests/test_model_save_path.py index 59e6a38..935e0af 100644 --- a/tests/test_model_save_path.py +++ b/tests/test_model_save_path.py @@ -5,9 +5,9 @@ def test_get_model_save_path(): """Test the get_model_save_path function.""" - model_path, model_dir, pretrained_path = get_model_save_path() - assert isinstance(model_path, Path), "Model path should be a Path object" + pretrained_model, model_dir = get_model_save_path() + assert isinstance(pretrained_model, Path), "Model path should be a Path object" assert isinstance(model_dir, Path), "Model directory should be a Path object" - assert model_path.name == "pretrained_model.pth", "Model name should be 'pretrained_model.pth'" + assert pretrained_model.name == "pretrained_model.pth", "Model name should be 'pretrained_model.pth'" diff --git a/tests/test_prediction_calibration.py b/tests/test_prediction_calibration.py index 79a3a92..4314cb2 100644 --- a/tests/test_prediction_calibration.py +++ b/tests/test_prediction_calibration.py @@ -11,7 +11,7 @@ def test_predict(): """Test the predict function.""" # Mock data and model config = get_config() - best_model_path, model_dir, pretrained_model = get_model_save_path() + pretrained_model, model_dir = get_model_save_path() test_csv_path = "ideeplc/example_input/Hela_deeprt.csv" # Path to a sample test CSV file matrix_input, x_shape = data_initialize(csv_path=test_csv_path)