77
88"""
99import json
10- from os import getenv
1110from typing import Any , Self , Callable , Literal
1211
1312import numpy as np
1413import pandas as pd
15- from dotenv import load_dotenv
14+ from dotenv import dotenv_values
1615from openai import OpenAI as LlmClient
1716from openai .types .chat import ChatCompletion
1817from openai .types .chat .chat_completion import Choice
@@ -37,15 +36,15 @@ def __init__(self: Self, env_file_path: str = None) -> None:
3736
3837 # If .env filepath is supplied, use it. Or else '.env' is used.
3938 env_file_path = env_file_path or ".env"
40- load_dotenv ( dotenv_path = env_file_path )
39+ _config : dict = dotenv_values ( env_file_path )
4140
42- self .prediction_api : str = getenv ("DEFAULT_PREDICTION_API" )
41+ self .prediction_api : str = _config . get ("DEFAULT_PREDICTION_API" )
4342 print (f"\t [INFO]\t AI backend: `{ self .prediction_api } `." )
4443
4544 if "LLM" in self .prediction_api :
46- self .base_url : str = getenv ("LLM_BASE_URL" )
47- self .llm_api_key : str = getenv ("LLM_API_KEY" )
48- self .llm_model : str = getenv ("LLM_MODEL" )
45+ self .base_url : str = _config . get ("LLM_BASE_URL" )
46+ self .llm_api_key : str = _config . get ("LLM_API_KEY" )
47+ self .llm_model : str = _config . get ("LLM_MODEL" )
4948
5049 match self .prediction_api :
5150 case "LLM" :
@@ -56,19 +55,19 @@ def __init__(self: Self, env_file_path: str = None) -> None:
5655 "Predict probability of uptrend"
5756 "(respond with a single number between 0.0 and 100.0; "
5857 "no other information!)" )
59- self .lower_prob : float = float (getenv ("LOWER_PROB" ) or 20 )
60- self .upper_prob : float = float (getenv ("UPPER_PROB" ) or 80 )
58+ self .lower_prob : float = float (_config . get ("LOWER_PROB" , 20 ) )
59+ self .upper_prob : float = float (_config . get ("UPPER_PROB" , 80 ) )
6160 if not 0.0 <= self .lower_prob <= self .upper_prob <= 100.0 :
6261 self .lower_prob = 20.0
6362 self .upper_prob = 80.0
6463
6564 case "PANDAS" :
6665 self .indicators : set [str ] = set (json .loads (
67- getenv ("PREDICTION_INDICATORS_JSON" )
66+ _config . get ("PREDICTION_INDICATORS_JSON" )
6867 ))
69- self .price_type_column_name : str = getenv ("PREDICTION_OPERATIONAL_PRICE_TYPE" )
68+ self .price_type_column_name : str = _config . get ("PREDICTION_OPERATIONAL_PRICE_TYPE" )
7069 # Take an n-period lag for better signals
71- self .wait_for_n_signal_lags : int = int (getenv ("PREDICTION_GLOBAL_SIGNAL_LAG" ))
70+ self .wait_for_n_signal_lags : int = int (_config . get ("PREDICTION_GLOBAL_SIGNAL_LAG" ))
7271 self .df : pd .DataFrame | None = None
7372
7473 case _:
0 commit comments