
Commit 0f907aa

fix tests
1 parent 2422b05 commit 0f907aa

File tree

15 files changed: +59 -21 lines changed


optillm.py

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,7 @@
 from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
 from optillm.batching import RequestBatcher, BatchingError
 from optillm.conversation_logger import ConversationLogger
+import optillm.conversation_logger
 
 # Setup logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

@@ -1133,6 +1134,8 @@ def process_batch_requests(batch_requests):
         log_dir=Path(server_config['conversation_log_dir']),
         enabled=server_config['log_conversations']
     )
+    # Set the global logger instance for access from approach modules
+    optillm.conversation_logger.set_global_logger(conversation_logger)
     if server_config['log_conversations']:
         logger.info(f"Conversation logging enabled. Logs will be saved to: {server_config['conversation_log_dir']}")
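
Taken together, the two hunks above publish a single process-wide logger. A minimal sketch of the pattern, assuming a standalone script rather than the real server startup (the log directory value here is illustrative, not the server default):

from pathlib import Path

import optillm.conversation_logger
from optillm.conversation_logger import ConversationLogger

# Build one instance, then install it as the module-level singleton so that
# approach modules (bon, moa, mcts, ...) can log through module functions.
conversation_logger = ConversationLogger(
    log_dir=Path("logs/conversations"),  # illustrative path
    enabled=True,
)
optillm.conversation_logger.set_global_logger(conversation_logger)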

optillm/bon.py

Lines changed: 7 additions & 6 deletions
@@ -1,5 +1,6 @@
 import logging
 import optillm
+from optillm import conversation_logger
 
 logger = logging.getLogger(__name__)
 
@@ -23,9 +24,9 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
     response = client.chat.completions.create(**provider_request)
 
     # Log provider call
-    if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+    if request_id:
         response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-        optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
 
     completions = [choice.message.content for choice in response.choices]
     logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
@@ -47,9 +48,9 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
     response = client.chat.completions.create(**provider_request)
 
     # Log provider call
-    if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+    if request_id:
         response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-        optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
 
     completions.append(response.choices[0].message.content)
     bon_completion_tokens += response.usage.completion_tokens
@@ -84,9 +85,9 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
     rating_response = client.chat.completions.create(**provider_request)
 
     # Log provider call
-    if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+    if request_id:
         response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
-        optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
 
     bon_completion_tokens += rating_response.usage.completion_tokens
     try:
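
The same three-line pattern now repeats at every provider call site in this file (and in mcts.py and moa.py below): the old `hasattr(optillm, 'conversation_logger')` guard is gone because the imported module-level `log_provider_call` is safe to call unconditionally. A hedged sketch of the call-site shape, where `call_and_log` is a hypothetical helper, not a function in the repo:

from optillm import conversation_logger

def call_and_log(client, provider_request: dict, request_id=None):
    """Hypothetical helper showing the post-commit call-site pattern."""
    response = client.chat.completions.create(**provider_request)
    if request_id:
        # model_dump() exists on pydantic-style responses; otherwise log
        # the raw object, exactly as the sites above do.
        response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
        # No extra guard needed: this is a no-op until optillm.py installs
        # a global logger, and again whenever logging is disabled.
        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
    return response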

optillm/cepo/cepo.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 import yaml
 import json
 import optillm
+from optillm import conversation_logger
 
 from dataclasses import dataclass
 from typing import Literal, Any, Optional

optillm/conversation_logger.py

Lines changed: 23 additions & 1 deletion
@@ -10,6 +10,9 @@
 
 logger = logging.getLogger(__name__)
 
+# Global logger instance - will be set by optillm.py
+_global_logger: Optional['ConversationLogger'] = None
+
 @dataclass
 class ConversationEntry:
     """Represents a single conversation entry being logged"""

@@ -240,4 +243,23 @@ def get_stats(self) -> Dict[str, Any]:
             "total_entries_approximate": total_entries
         })
 
-        return stats
+        return stats
+
+
+# Module-level functions for easy access from approach modules
+def set_global_logger(logger_instance: 'ConversationLogger') -> None:
+    """Set the global logger instance (called by optillm.py)"""
+    global _global_logger
+    _global_logger = logger_instance
+
+
+def log_provider_call(request_id: str, provider_request: Dict[str, Any], provider_response: Dict[str, Any]) -> None:
+    """Log a provider call using the global logger instance"""
+    if _global_logger and _global_logger.enabled:
+        _global_logger.log_provider_call(request_id, provider_request, provider_response)
+
+
+def log_error(request_id: str, error_message: str) -> None:
+    """Log an error using the global logger instance"""
+    if _global_logger and _global_logger.enabled:
+        _global_logger.log_error(request_id, error_message)
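
Because both module-level functions check `_global_logger and _global_logger.enabled`, they degrade to silent no-ops until `optillm.py` calls `set_global_logger`. A small sketch of that behavior; the `DummyLogger` class is a hypothetical stand-in for `ConversationLogger`, used only to show the dispatch:

from optillm import conversation_logger

# Before set_global_logger() has run, _global_logger is None, so this
# call does nothing rather than raising:
conversation_logger.log_provider_call("req-1", {"model": "stub"}, {})

class DummyLogger:
    """Hypothetical stand-in for ConversationLogger (prints instead of
    writing entries under log_dir)."""
    enabled = True

    def log_provider_call(self, request_id, provider_request, provider_response):
        print(f"provider call logged for {request_id}")

    def log_error(self, request_id, error_message):
        print(f"error logged for {request_id}: {error_message}")

conversation_logger.set_global_logger(DummyLogger())
conversation_logger.log_provider_call("req-1", {"model": "stub"}, {})  # dispatches now
conversation_logger.log_error("req-1", "upstream timeout")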

optillm/cot_reflection.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 import re
 import logging
 import optillm
+from optillm import conversation_logger
 
 logger = logging.getLogger(__name__)

optillm/leap.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from typing import List, Tuple
 import json
 import optillm
+from optillm import conversation_logger
 
 # Setup logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

optillm/mcts.py

Lines changed: 7 additions & 6 deletions
@@ -4,6 +4,7 @@
 import networkx as nx
 from typing import List, Dict
 import optillm
+from optillm import conversation_logger
 
 logger = logging.getLogger(__name__)
 
@@ -123,9 +124,9 @@ def generate_actions(self, state: DialogueState) -> List[str]:
         response = self.client.chat.completions.create(**provider_request)
 
         # Log provider call
-        if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+        if self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-            optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+            conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
 
         completions = [choice.message.content.strip() for choice in response.choices]
         self.completion_tokens += response.usage.completion_tokens
@@ -152,9 +153,9 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
         response = self.client.chat.completions.create(**provider_request)
 
         # Log provider call
-        if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+        if self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-            optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+            conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
 
         next_query = response.choices[0].message.content
         self.completion_tokens += response.usage.completion_tokens
@@ -182,9 +183,9 @@ def evaluate_state(self, state: DialogueState) -> float:
         response = self.client.chat.completions.create(**provider_request)
 
         # Log provider call
-        if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+        if self.request_id:
             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-            optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+            conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
 
         self.completion_tokens += response.usage.completion_tokens
         try:

optillm/moa.py

Lines changed: 9 additions & 8 deletions
@@ -1,5 +1,6 @@
 import logging
 import optillm
+from optillm import conversation_logger
 
 logger = logging.getLogger(__name__)
 
@@ -29,8 +30,8 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
     response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
 
     # Log provider call if conversation logging is enabled
-    if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
-        optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+    if request_id:
+        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
 
     completions = [choice.message.content for choice in response.choices]
     moa_completion_tokens += response.usage.completion_tokens
@@ -60,8 +61,8 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
     response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
 
     # Log provider call if conversation logging is enabled
-    if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
-        optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+    if request_id:
+        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
 
     completions.append(response.choices[0].message.content)
     moa_completion_tokens += response.usage.completion_tokens
@@ -122,8 +123,8 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
     response_dict = critique_response.model_dump() if hasattr(critique_response, 'model_dump') else critique_response
 
     # Log provider call if conversation logging is enabled
-    if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
-        optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+    if request_id:
+        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
 
     critiques = critique_response.choices[0].message.content
     moa_completion_tokens += critique_response.usage.completion_tokens
@@ -169,8 +170,8 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
     response_dict = final_response.model_dump() if hasattr(final_response, 'model_dump') else final_response
 
     # Log provider call if conversation logging is enabled
-    if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
-        optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+    if request_id:
+        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
 
     moa_completion_tokens += final_response.usage.completion_tokens
     logger.info(f"Generated final response. Tokens used: {final_response.usage.completion_tokens}")

optillm/plansearch.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 import logging
 from typing import List, Tuple
 import optillm
+from optillm import conversation_logger
 
 logger = logging.getLogger(__name__)

optillm/pvg.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 import re
 from typing import List, Tuple
 import optillm
+from optillm import conversation_logger
 
 logger = logging.getLogger(__name__)
