@@ -107,37 +107,72 @@ def prompt_for_config(use_ngrok: bool = None, port: int = None, ngrok_auth_token
107107 click .echo ("\n ⚡ Model Optimization Settings" )
108108 click .echo ("─────────────────────────────" )
109109
110- config ["enable_quantization" ] = click .confirm (
111- "Enable model quantization?" ,
112- default = config .get ("enable_quantization" , ENABLE_QUANTIZATION )
110+ # Show current values for reference
111+ click .echo ("\n Current optimization settings:" )
112+ click .echo (f" Quantization: { 'Enabled' if config .get ('enable_quantization' , ENABLE_QUANTIZATION ) else 'Disabled' } " )
113+ if config .get ('enable_quantization' , ENABLE_QUANTIZATION ):
114+ click .echo (f" Quantization Type: { config .get ('quantization_type' , QUANTIZATION_TYPE )} " )
115+ click .echo (f" CPU Offloading: { 'Enabled' if config .get ('enable_cpu_offloading' , ENABLE_CPU_OFFLOADING ) else 'Disabled' } " )
116+ click .echo (f" Attention Slicing: { 'Enabled' if config .get ('enable_attention_slicing' , ENABLE_ATTENTION_SLICING ) else 'Disabled' } " )
117+ click .echo (f" Flash Attention: { 'Enabled' if config .get ('enable_flash_attention' , ENABLE_FLASH_ATTENTION ) else 'Disabled' } " )
118+ click .echo (f" Better Transformer: { 'Enabled' if config .get ('enable_bettertransformer' , ENABLE_BETTERTRANSFORMER ) else 'Disabled' } " )
119+
120+ # Ask if user wants to configure optimization settings
121+ configure_optimization = click .confirm (
122+ "\n Would you like to configure model optimization settings?" ,
123+ default = True # Default to Yes for optimization settings
113124 )
114125
115- if config ["enable_quantization" ]:
116- config ["quantization_type" ] = click .prompt (
117- "Quantization type (fp16/int8/int4)" ,
118- default = config .get ("quantization_type" , QUANTIZATION_TYPE ),
119- type = click .Choice (["fp16" , "int8" , "int4" ])
126+ if configure_optimization :
127+ config ["enable_quantization" ] = click .confirm (
128+ "Enable model quantization?" ,
129+ default = config .get ("enable_quantization" , ENABLE_QUANTIZATION )
120130 )
121131
122- config ["enable_cpu_offloading" ] = click .confirm (
123- "Enable CPU offloading?" ,
124- default = config .get ("enable_cpu_offloading" , ENABLE_CPU_OFFLOADING )
125- )
132+ if config ["enable_quantization" ]:
133+ config ["quantization_type" ] = click .prompt (
134+ "Quantization type (fp16/int8/int4)" ,
135+ default = config .get ("quantization_type" , QUANTIZATION_TYPE ),
136+ type = click .Choice (["fp16" , "int8" , "int4" ])
137+ )
126138
127- config ["enable_attention_slicing " ] = click .confirm (
128- "Enable attention slicing ?" ,
129- default = config .get ("enable_attention_slicing " , ENABLE_ATTENTION_SLICING )
130- )
139+ config ["enable_cpu_offloading " ] = click .confirm (
140+ "Enable CPU offloading ?" ,
141+ default = config .get ("enable_cpu_offloading " , ENABLE_CPU_OFFLOADING )
142+ )
131143
132- config ["enable_flash_attention " ] = click .confirm (
133- "Enable flash attention?" ,
134- default = config .get ("enable_flash_attention " , ENABLE_FLASH_ATTENTION )
135- )
144+ config ["enable_attention_slicing " ] = click .confirm (
145+ "Enable attention slicing ?" ,
146+ default = config .get ("enable_attention_slicing " , ENABLE_ATTENTION_SLICING )
147+ )
136148
137- config ["enable_better_transformer" ] = click .confirm (
138- "Enable better transformer?" ,
139- default = config .get ("enable_bettertransformer" , ENABLE_BETTERTRANSFORMER )
140- )
149+ config ["enable_flash_attention" ] = click .confirm (
150+ "Enable flash attention?" ,
151+ default = config .get ("enable_flash_attention" , ENABLE_FLASH_ATTENTION )
152+ )
153+
154+ config ["enable_better_transformer" ] = click .confirm (
155+ "Enable better transformer?" ,
156+ default = config .get ("enable_bettertransformer" , ENABLE_BETTERTRANSFORMER )
157+ )
158+
159+ click .echo ("\n ✅ Optimization settings updated!" )
160+ else :
161+ # If user doesn't want to configure, use the current values or defaults
162+ if 'enable_quantization' not in config :
163+ config ["enable_quantization" ] = ENABLE_QUANTIZATION
164+ if config ["enable_quantization" ] and 'quantization_type' not in config :
165+ config ["quantization_type" ] = QUANTIZATION_TYPE
166+ if 'enable_cpu_offloading' not in config :
167+ config ["enable_cpu_offloading" ] = ENABLE_CPU_OFFLOADING
168+ if 'enable_attention_slicing' not in config :
169+ config ["enable_attention_slicing" ] = ENABLE_ATTENTION_SLICING
170+ if 'enable_flash_attention' not in config :
171+ config ["enable_flash_attention" ] = ENABLE_FLASH_ATTENTION
172+ if 'enable_bettertransformer' not in config :
173+ config ["enable_bettertransformer" ] = ENABLE_BETTERTRANSFORMER
174+
175+ click .echo ("\n Using current optimization settings." )
141176
142177 # Advanced Settings
143178 # ----------------
@@ -150,40 +185,89 @@ def prompt_for_config(use_ngrok: bool = None, port: int = None, ngrok_auth_token
150185 type = int
151186 )
152187
153- # Generation Parameters
154- # -------------------
155- click .echo ("\n 🔄 Generation Parameters" )
156- click .echo ("─────────────────────" )
157-
158- config ["max_length" ] = click .prompt (
159- "Maximum generation length (tokens)" ,
160- default = config .get ("max_length" , 8192 ),
161- type = int
188+ # Response Quality Settings
189+ # -----------------------
190+ click .echo ("\n 🎯 Response Quality Settings" )
191+ click .echo ("───────────────────────────" )
192+
193+ # Show current values for reference with descriptions
194+ click .echo ("\n Current response quality settings:" )
195+ click .echo (f" Max Length: { config .get ('max_length' , 8192 )} tokens - Maximum number of tokens in the generated response" )
196+ click .echo (f" Temperature: { config .get ('temperature' , 0.7 )} - Controls randomness (higher = more creative, lower = more focused)" )
197+ click .echo (f" Top-p: { config .get ('top_p' , 0.9 )} - Nucleus sampling parameter (higher = more diverse responses)" )
198+ click .echo (f" Top-k: { config .get ('top_k' , 80 )} - Limits vocabulary to top K tokens (higher = more diverse vocabulary)" )
199+ click .echo (f" Repetition Penalty: { config .get ('repetition_penalty' , 1.15 )} - Penalizes repetition (higher = less repetition)" )
200+ click .echo (f" Max Time: { config .get ('max_time' , 120.0 )} seconds - Maximum time allowed for generation" )
201+
202+ # Ask if user wants to configure response quality settings
203+ configure_response_quality = click .confirm (
204+ "\n Would you like to configure response quality settings?" ,
205+ default = False # Default to No
162206 )
163207
164- config ["temperature" ] = click .prompt (
165- "Temperature (0.1-1.0)" ,
166- default = config .get ("temperature" , 0.7 ),
167- type = float
168- )
208+ if configure_response_quality :
209+ # If user wants to configure, show the prompts with descriptions
210+ config ["max_length" ] = click .prompt (
211+ "Maximum generation length in tokens (higher = longer responses, but slower)" ,
212+ default = config .get ("max_length" , 8192 ),
213+ type = int
214+ )
169215
170- config ["top_p " ] = click .prompt (
171- "Top-p (0.1-1.0)" ,
172- default = config .get ("top_p " , 0.9 ),
173- type = float
174- )
216+ config ["temperature " ] = click .prompt (
217+ "Temperature (0.1-1.0, higher = more creative, lower = more focused )" ,
218+ default = config .get ("temperature " , 0.7 ),
219+ type = float
220+ )
175221
176- config ["top_k " ] = click .prompt (
177- "Top-k (1-100 )" ,
178- default = config .get ("top_k " , 80 ),
179- type = int
180- )
222+ config ["top_p " ] = click .prompt (
223+ "Top-p (0.1-1.0, higher = more diverse responses )" ,
224+ default = config .get ("top_p " , 0.9 ),
225+ type = float
226+ )
181227
182- config ["repetition_penalty" ] = click .prompt (
183- "Repetition penalty (1.0-2.0)" ,
184- default = config .get ("repetition_penalty" , 1.15 ),
185- type = float
186- )
228+ config ["top_k" ] = click .prompt (
229+ "Top-k (1-100, higher = more diverse vocabulary)" ,
230+ default = config .get ("top_k" , 80 ),
231+ type = int
232+ )
233+
234+ config ["repetition_penalty" ] = click .prompt (
235+ "Repetition penalty (1.0-2.0, higher = less repetition)" ,
236+ default = config .get ("repetition_penalty" , 1.15 ),
237+ type = float
238+ )
239+
240+ config ["max_time" ] = click .prompt (
241+ "Maximum generation time in seconds (higher = more complete responses, but slower)" ,
242+ default = config .get ("max_time" , 120.0 ),
243+ type = float
244+ )
245+
246+ click .echo ("\n ✅ Response quality settings updated!" )
247+ else :
248+ # If user doesn't want to configure, use the current values or defaults
249+ if 'max_length' not in config :
250+ config ["max_length" ] = 8192
251+ if 'temperature' not in config :
252+ config ["temperature" ] = 0.7
253+ if 'top_p' not in config :
254+ config ["top_p" ] = 0.9
255+ if 'top_k' not in config :
256+ config ["top_k" ] = 80
257+ if 'repetition_penalty' not in config :
258+ config ["repetition_penalty" ] = 1.15
259+ if 'max_time' not in config :
260+ config ["max_time" ] = 120.0
261+
262+ click .echo ("\n Using default response quality settings." )
263+
264+ # Set environment variables for these settings
265+ os .environ ["DEFAULT_MAX_LENGTH" ] = str (config ["max_length" ])
266+ os .environ ["DEFAULT_TEMPERATURE" ] = str (config ["temperature" ])
267+ os .environ ["DEFAULT_TOP_P" ] = str (config ["top_p" ])
268+ os .environ ["DEFAULT_TOP_K" ] = str (config ["top_k" ])
269+ os .environ ["DEFAULT_REPETITION_PENALTY" ] = str (config ["repetition_penalty" ])
270+ os .environ ["DEFAULT_MAX_TIME" ] = str (config ["max_time" ])
187271
188272 # Cache Settings
189273 # -------------
0 commit comments