1+ from pathlib import Path
2+
13import yaml
24import click
35import openai
@@ -47,22 +49,22 @@ class ConfigHelper:
4749 presence_penalty : float = DEFAULT_PRESENCE_PENALTY
4850
4951 @classmethod
50- def from_file (cls ) -> 'ConfigHelper' :
51- if CONFIG_PATH .is_file ():
52- with open (CONFIG_PATH , "r" ) as f :
52+ def from_file (cls , config_path : Path = CONFIG_PATH ) -> 'ConfigHelper' :
53+ if config_path .is_file ():
54+ with open (config_path , "r" ) as f :
5355 config = yaml .safe_load (f )
5456 return cls (** config )
5557 else :
5658 click .echo (click .style ("No config file found, can't initialize config. "
5759 "Run 'askai config reset' to create a default config." , fg = "red" ))
5860 exit ()
5961
60- def input_model (self ) -> None :
62+ def input_model (self , max_input_tries : int = MAX_INPUT_TRIES ) -> None :
6163 model = input ("Choose model (1-4): " )
6264 num_of_tries = 1
6365
6466 while not _is_int (model ) or int (model ) not in range (1 , 5 ):
65- if num_of_tries >= MAX_INPUT_TRIES :
67+ if num_of_tries >= max_input_tries :
6668 click .echo (click .style ("Too many invalid tries. Aborted!" , fg = "red" ))
6769 exit (1 )
6870
@@ -74,56 +76,84 @@ def input_model(self) -> None:
7476 click .echo (click .style (f"Model chosen: { self .model } " , fg = "green" ))
7577 click .echo ()
7678
77- def input_num_answer (self ) -> None :
79+ def input_num_answer (self ,
80+ default_value : int = DEFAULT_NUM_ANSWERS ,
81+ min_value : int = OPENAI_NUM_ANSWERS_MIN ,
82+ max_input_tries : int = MAX_INPUT_TRIES ) -> None :
7883 self .num_answers = self ._input_integer (
79- default_value = DEFAULT_NUM_ANSWERS ,
80- predicate = lambda x : x >= OPENAI_NUM_ANSWERS_MIN
84+ default_value = default_value ,
85+ predicate = lambda x : x >= min_value ,
86+ max_input_tries = max_input_tries
8187 )
8288
83- def input_max_token (self ) -> None :
89+ def input_max_token (self ,
90+ default_value : int = DEFAULT_MAX_TOKENS ,
91+ min_value : int = OPENAI_MAX_TOKENS_MIN ,
92+ max_input_tries : int = MAX_INPUT_TRIES ) -> None :
8493 self .max_tokens = self ._input_integer (
85- default_value = DEFAULT_MAX_TOKENS ,
86- predicate = lambda x : x >= OPENAI_MAX_TOKENS_MIN
94+ default_value = default_value ,
95+ predicate = lambda x : x >= min_value ,
96+ max_input_tries = max_input_tries
8797 )
8898
89- def input_temperature (self ) -> None :
99+ def input_temperature (self ,
100+ default_value : float = DEFAULT_TEMPERATURE ,
101+ min_value : float = OPENAI_TEMPERATURE_MIN ,
102+ max_value : float = OPENAI_TEMPERATURE_MAX ,
103+ max_input_tries : int = MAX_INPUT_TRIES ) -> None :
90104 self .temperature = self ._input_float (
91- default_value = DEFAULT_TEMPERATURE ,
92- predicate = lambda x : OPENAI_TEMPERATURE_MIN <= x <= OPENAI_TEMPERATURE_MAX
105+ default_value = default_value ,
106+ predicate = lambda x : min_value <= x <= max_value ,
107+ max_input_tries = max_input_tries
93108 )
94109
95- def input_top_p (self ) -> None :
110+ def input_top_p (self ,
111+ default_value : float = DEFAULT_TOP_P ,
112+ min_value : float = OPENAI_TOP_P_MIN ,
113+ max_value : float = OPENAI_TOP_P_MAX ,
114+ max_input_tries : int = MAX_INPUT_TRIES ) -> None :
96115 self .top_p = self ._input_float (
97- default_value = DEFAULT_TOP_P ,
98- predicate = lambda x : OPENAI_TOP_P_MIN <= x <= OPENAI_TOP_P_MAX
116+ default_value = default_value ,
117+ predicate = lambda x : min_value <= x <= max_value ,
118+ max_input_tries = max_input_tries
99119 )
100120
101- def input_frequency_penalty (self ) -> None :
121+ def input_frequency_penalty (self ,
122+ default_value : float = DEFAULT_FREQUENCY_PENALTY ,
123+ min_value : float = OPENAI_FREQUENCY_PENALTY_MIN ,
124+ max_value : float = OPENAI_FREQUENCY_PENALTY_MAX ,
125+ max_input_tries : int = MAX_INPUT_TRIES ) -> None :
102126 self .frequency_penalty = self ._input_float (
103- default_value = DEFAULT_FREQUENCY_PENALTY ,
104- predicate = lambda x : OPENAI_FREQUENCY_PENALTY_MIN <= x <= OPENAI_FREQUENCY_PENALTY_MAX
127+ default_value = default_value ,
128+ predicate = lambda x : min_value <= x <= max_value ,
129+ max_input_tries = max_input_tries
105130 )
106131
107- def input_presence_penalty (self ) -> None :
132+ def input_presence_penalty (self ,
133+ default_value : float = DEFAULT_PRESENCE_PENALTY ,
134+ min_value : float = OPENAI_PRESENCE_PENALTY_MIN ,
135+ max_value : float = OPENAI_PRESENCE_PENALTY_MAX ,
136+ max_input_tries : int = MAX_INPUT_TRIES ) -> None :
108137 self .presence_penalty = self ._input_float (
109- default_value = DEFAULT_PRESENCE_PENALTY ,
110- predicate = lambda x : OPENAI_PRESENCE_PENALTY_MIN <= x <= OPENAI_PRESENCE_PENALTY_MAX
138+ default_value = default_value ,
139+ predicate = lambda x : min_value <= x <= max_value ,
140+ max_input_tries = max_input_tries
111141 )
112142
113143 def as_dict (self ) -> dict :
114144 return asdict (self )
115145
116- def update (self ) -> None :
146+ def update (self , config_path : Path = CONFIG_PATH ) -> None :
117147 config = self .as_dict ()
118- with open (CONFIG_PATH , "w" ) as f :
148+ with open (config_path , "w" ) as f :
119149 yaml .dump (config , f )
120150
121151 click .echo (click .style ("Config updated successfully!" , fg = "green" ))
122152
123153 @staticmethod
124- def reset () -> None :
154+ def reset (config_path : Path = CONFIG_PATH ) -> None :
125155 config = ConfigHelper ().as_dict () # Create config with default values
126- with open (CONFIG_PATH , "w" ) as f :
156+ with open (config_path , "w" ) as f :
127157 yaml .dump (config , f )
128158
129159 click .echo ("\n Default config has been created with the following values:" )
@@ -133,12 +163,12 @@ def reset() -> None:
133163 click .echo ("To change the config, please see: 'askai config --help'\n " )
134164
135165 @staticmethod
136- def show () -> None :
137- if not CONFIG_PATH .is_file ():
166+ def show (config_path : Path = CONFIG_PATH ) -> None :
167+ if not config_path .is_file ():
138168 click .echo ("No config exists. Please reset the config ('askai config reset') "
139169 "or see 'askai config --help'.\n " )
140170 else :
141- with open (CONFIG_PATH , "r" ) as f :
171+ with open (config_path , "r" ) as f :
142172 try :
143173 config = yaml .safe_load (f )
144174 for key , value in config .items ():
@@ -148,8 +178,9 @@ def show() -> None:
148178
149179 @staticmethod
150180 def _input_integer (default_value : int ,
151- predicate : Callable [[int ], bool ] = lambda x : True ) -> int :
152- for _ in range (MAX_INPUT_TRIES ):
181+ predicate : Callable [[int ], bool ] = lambda x : True ,
182+ max_input_tries : int = MAX_INPUT_TRIES ) -> int :
183+ for _ in range (max_input_tries ):
153184 input_value = input (f"Choose (press enter for default = { default_value } ): " )
154185
155186 if input_value == "" :
@@ -173,8 +204,9 @@ def _input_integer(default_value: int,
173204
174205 @staticmethod
175206 def _input_float (default_value : float ,
176- predicate : Callable [[float ], bool ] = lambda x : True ) -> float :
177- for _ in range (MAX_INPUT_TRIES ):
207+ predicate : Callable [[float ], bool ] = lambda x : True ,
208+ max_input_tries : int = MAX_INPUT_TRIES ) -> float :
209+ for _ in range (max_input_tries ):
178210 input_value = input (f"Choose (press enter for default = { default_value } ): " )
179211
180212 if input_value == "" :
0 commit comments