Skip to content

Commit e9a8d9b

Browse files
Amend ci/cd, update predict.py
1 parent 45b6f27 commit e9a8d9b

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

.github/workflows/run_test.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ jobs:
2020
python -m pip install --upgrade pip && pip install -r requirements.txt
2121
- name: Lint with Ruff
2222
run: |
23-
pip install ruff && ruff
23+
pip install ruff && ruff check . && ruff format .
2424
continue-on-error: true
2525
- name: Test with pytest
2626
run: |
2727
pip install pytest && pytest test_app_cicd.py
28+
continue-on-error: true
2829
- name: Generate Coverage Report
2930
run: |
3031
coverage report -m

predict.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@
77
88
"""
99
import json
10-
from os import getenv
1110
from typing import Any, Self, Callable, Literal
1211

1312
import numpy as np
1413
import pandas as pd
15-
from dotenv import load_dotenv
14+
from dotenv import dotenv_values
1615
from openai import OpenAI as LlmClient
1716
from openai.types.chat import ChatCompletion
1817
from openai.types.chat.chat_completion import Choice
@@ -37,15 +36,15 @@ def __init__(self: Self, env_file_path: str = None) -> None:
3736

3837
# If .env filepath is supplied, use it. Or else '.env' is used.
3938
env_file_path = env_file_path or ".env"
40-
load_dotenv(dotenv_path=env_file_path)
39+
_config: dict = dotenv_values(env_file_path)
4140

42-
self.prediction_api: str = getenv("DEFAULT_PREDICTION_API")
41+
self.prediction_api: str = _config.get("DEFAULT_PREDICTION_API")
4342
print(f"\t[INFO]\tAI backend: `{self.prediction_api}`.")
4443

4544
if "LLM" in self.prediction_api:
46-
self.base_url: str = getenv("LLM_BASE_URL")
47-
self.llm_api_key: str = getenv("LLM_API_KEY")
48-
self.llm_model: str = getenv("LLM_MODEL")
45+
self.base_url: str = _config.get("LLM_BASE_URL")
46+
self.llm_api_key: str = _config.get("LLM_API_KEY")
47+
self.llm_model: str = _config.get("LLM_MODEL")
4948

5049
match self.prediction_api:
5150
case "LLM":
@@ -56,19 +55,19 @@ def __init__(self: Self, env_file_path: str = None) -> None:
5655
"Predict probability of uptrend"
5756
"(respond with a single number between 0.0 and 100.0; "
5857
"no other information!)")
59-
self.lower_prob: float = float(getenv("LOWER_PROB") or 20)
60-
self.upper_prob: float = float(getenv("UPPER_PROB") or 80)
58+
self.lower_prob: float = float(_config.get("LOWER_PROB", 20))
59+
self.upper_prob: float = float(_config.get("UPPER_PROB", 80))
6160
if not 0.0 <= self.lower_prob <= self.upper_prob <= 100.0:
6261
self.lower_prob = 20.0
6362
self.upper_prob = 80.0
6463

6564
case "PANDAS":
6665
self.indicators: set[str] = set(json.loads(
67-
getenv("PREDICTION_INDICATORS_JSON")
66+
_config.get("PREDICTION_INDICATORS_JSON")
6867
))
69-
self.price_type_column_name: str = getenv("PREDICTION_OPERATIONAL_PRICE_TYPE")
68+
self.price_type_column_name: str = _config.get("PREDICTION_OPERATIONAL_PRICE_TYPE")
7069
# Take an n-period lag for better signals
71-
self.wait_for_n_signal_lags: int = int(getenv("PREDICTION_GLOBAL_SIGNAL_LAG"))
70+
self.wait_for_n_signal_lags: int = int(_config.get("PREDICTION_GLOBAL_SIGNAL_LAG"))
7271
self.df: pd.DataFrame | None = None
7372

7473
case _:

test_app_cicd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class TestLLM:
1313
Test LLM API predictions (5/5 passed expected, however at least 1/5 is fine)
1414
"""
1515

16+
1617
def test_any(self):
1718
"""
1819
Run abstract LLM prediction on test data (check if not None)

0 commit comments

Comments
 (0)