Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ideeplc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def _argument_parser() -> argparse.ArgumentParser:
action="store_true",
help="Flag to enable fine-tuning of the model.",
)
parser.add_argument(
"-m",
"--model",
type=str,
required=False,
help="Path to the pretrained model.",
)
parser.add_argument(
"-l",
"--log_level",
Expand Down
11 changes: 9 additions & 2 deletions ideeplc/ideeplc_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_model_save_path():
# If normal Python environment
model_path = files("ideeplc.models").joinpath(model_name)

return model_path, model_dir, model_path
return model_path, model_dir


def main(args):
Expand All @@ -81,7 +81,14 @@ def main(args):

# Load pre-trained model
LOGGER.info("Loading pre-trained model")
best_model_path, model_dir, pretrained_model = get_model_save_path()
pretrained_model, model_dir = get_model_save_path()
if args.model:
try:
logging.info(f"Using user-specified model path: {args.model}")
pretrained_model = Path(args.model)
except Exception as e:
LOGGER.error(f"Invalid model path provided: {e}")
raise e
model.load_state_dict(
torch.load(pretrained_model, map_location=device), strict=False
)
Expand Down
Binary file modified ideeplc/models/pretrained_model.pth
Binary file not shown.
6 changes: 3 additions & 3 deletions tests/test_model_save_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

def test_get_model_save_path():
"""Test the get_model_save_path function."""
model_path, model_dir, pretrained_path = get_model_save_path()
assert isinstance(model_path, Path), "Model path should be a Path object"
pretrained_model, model_dir = get_model_save_path()
assert isinstance(pretrained_model, Path), "Model path should be a Path object"
assert isinstance(model_dir, Path), "Model directory should be a Path object"
assert model_path.name == "pretrained_model.pth", "Model name should be 'pretrained_model.pth'"
assert pretrained_model.name == "pretrained_model.pth", "Model name should be 'pretrained_model.pth'"


2 changes: 1 addition & 1 deletion tests/test_prediction_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_predict():
"""Test the predict function."""
# Mock data and model
config = get_config()
best_model_path, model_dir, pretrained_model = get_model_save_path()
pretrained_model, model_dir = get_model_save_path()

test_csv_path = "ideeplc/example_input/Hela_deeprt.csv" # Path to a sample test CSV file
matrix_input, x_shape = data_initialize(csv_path=test_csv_path)
Expand Down