Skip to content

Commit fae50d7

Browse files
Merge pull request #20 from cafferychen777/dev
Fix OpenRouter model handling in consensus check process
2 parents: 0cb3091 + 71e1aa8; commit: fae50d7

File tree

4 files changed

+149
-125
lines changed

4 files changed

+149
-125
lines changed

R/R/check_consensus.R

Lines changed: 71 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -11,45 +11,59 @@ check_consensus <- function(round_responses, api_keys = NULL, controversy_thresh
1111
# Initialize logging
1212
write_log("\n=== Starting check_consensus function ===")
1313
write_log(sprintf("Input responses: %s", paste(round_responses, collapse = "; ")))
14-
14+
1515
# Validate input
1616
if (length(round_responses) < 2) {
1717
write_log("WARNING: Not enough responses to check consensus")
1818
return(list(reached = FALSE, consensus_proportion = 0, entropy = 0, majority_prediction = "Insufficient_Responses"))
1919
}
20-
20+
2121
# Get the formatted prompt from the dedicated function
2222
# Parameters are used in prompt template to instruct LLM on threshold values # nolint
2323
formatted_responses <- create_consensus_check_prompt(round_responses, controversy_threshold, entropy_threshold)
24-
24+
2525
# Get model analysis with retry mechanism
2626
max_retries <- 3
2727
response <- "0\n0\n0\nUnknown" # Default response in case all attempts fail
2828
success <- FALSE
29-
29+
3030
# Define models to try, in order of preference
3131
models_to_try <- c()
32-
32+
3333
# If consensus_check_model is specified, prioritize it
3434
if (!is.null(consensus_check_model)) {
3535
write_log(sprintf("Using specified consensus check model: %s", consensus_check_model))
36-
models_to_try <- c(consensus_check_model)
36+
37+
# Check if this is an OpenRouter model (contains a slash)
38+
if (grepl("/", consensus_check_model)) {
39+
# For OpenRouter models, we need to extract the base model name
40+
# Format is typically "provider/model" like "google/gemini-2.5-pro-preview-03-25"
41+
parts <- strsplit(consensus_check_model, "/")[[1]]
42+
if (length(parts) > 1) {
43+
# Use the model part (after the slash)
44+
base_model <- parts[2]
45+
write_log(sprintf("Detected OpenRouter model. Using base model name: %s", base_model))
46+
models_to_try <- c(consensus_check_model, base_model)
47+
} else {
48+
models_to_try <- c(consensus_check_model)
49+
}
50+
} else {
51+
models_to_try <- c(consensus_check_model)
52+
}
3753
}
38-
54+
3955
# Add fallback models
4056
fallback_models <- c("qwen-max-2025-01-25", "claude-3-5-sonnet-latest", "gpt-4o", "gemini-2.0-flash")
41-
# Remove the consensus_check_model from fallback_models if it's already there
42-
if (!is.null(consensus_check_model)) {
43-
fallback_models <- fallback_models[fallback_models != consensus_check_model]
44-
}
57+
# Remove any models that are already in models_to_try
58+
fallback_models <- fallback_models[!fallback_models %in% models_to_try]
4559
models_to_try <- c(models_to_try, fallback_models)
46-
60+
4761
# Try each model in order
4862
for (model_name in models_to_try) {
4963
if (success) break
50-
64+
5165
write_log(sprintf("Trying model %s for consensus check", model_name))
52-
66+
5367
# Get API key for this model
5468
api_key <- get_api_key(model_name, api_keys)
5569
if (is.null(api_key) || nchar(api_key) == 0) {
@@ -60,54 +74,54 @@ check_consensus <- function(round_responses, api_keys = NULL, controversy_thresh
6074
write_log(sprintf("ERROR: Could not determine provider for model %s: %s", model_name, e$message))
6175
return(NULL)
6276
})
63-
77+
6478
if (!is.null(provider)) {
6579
env_var <- paste0(toupper(provider), "_API_KEY")
6680
api_key <- Sys.getenv(env_var)
6781
}
6882
}
69-
83+
7084
# Skip if no API key available
7185
if (is.null(api_key) || nchar(api_key) == 0) {
7286
write_log(sprintf("No API key available for %s, skipping", model_name))
7387
next
7488
}
75-
89+
7690
# Try with current model using retry mechanism
7791
for (attempt in 1:max_retries) {
7892
write_log(sprintf("Attempt %d of %d with model %s", attempt, max_retries, model_name))
79-
93+
8094
tryCatch({
8195
# Use get_model_response which automatically selects the right processor
8296
temp_response <- get_model_response(
8397
formatted_responses,
8498
model_name,
8599
api_key
86100
)
87-
101+
88102
# Ensure response is a single string
89103
if (is.character(temp_response) && length(temp_response) > 1) {
90104
temp_response <- paste(temp_response, collapse = "\n")
91105
}
92-
106+
93107
write_log(sprintf("Successfully got response from %s", model_name))
94108
response <- temp_response
95109
success <- TRUE
96110
break
97111
}, error = function(e) {
98112
write_log(sprintf("ERROR on %s attempt %d: %s", model_name, attempt, e$message))
99-
113+
100114
if (attempt < max_retries) {
101115
wait_time <- 5 * 2^(attempt-1)
102116
write_log(sprintf("Waiting for %d seconds before next attempt...", wait_time))
103117
Sys.sleep(wait_time)
104118
}
105119
})
106-
120+
107121
if (success) break
108122
}
109123
}
110-
124+
111125
# If all models failed, return default values with warning
112126
if (!success) {
113127
# Note: We don't use a local statistical method here because simple statistical methods
@@ -119,7 +133,7 @@ check_consensus <- function(round_responses, api_keys = NULL, controversy_thresh
119133
warning("All available models failed for consensus check. Please ensure at least one model API key is valid.")
120134
return(list(reached = FALSE, consensus_proportion = 0, entropy = 0, majority_prediction = "Unknown"))
121135
}
122-
136+
123137
# Directly parse the response using a simpler approach
124138
# First, check if the response contains newlines
125139
if (grepl("\n", response)) {
@@ -131,12 +145,12 @@ check_consensus <- function(round_responses, api_keys = NULL, controversy_thresh
131145
# If no newlines, treat as a single line
132146
lines <- c(response)
133147
}
134-
148+
135149
# Get the last 4 non-empty lines, as the model might include explanatory text before the actual results
136150
if (length(lines) >= 4) {
137151
# Extract the last 4 lines
138152
result_lines <- tail(lines, 4)
139-
153+
140154
# First try to process the standard four-line format
141155
# Check if it's a standard four-line format
142156
standard_format <- FALSE
@@ -147,36 +161,36 @@ check_consensus <- function(round_responses, api_keys = NULL, controversy_thresh
147161
is_line2_valid <- grepl("^\\s*(0\\.\\d+|1\\.0*|1)\\s*$", result_lines[2])
148162
# Check if the third line is a non-negative number
149163
is_line3_valid <- grepl("^\\s*(\\d+\\.\\d+|\\d+)\\s*$", result_lines[3])
150-
164+
151165
if (is_line1_valid && is_line2_valid && is_line3_valid) {
152166
standard_format <- TRUE
153167
write_log("Detected standard 4-line format")
154-
168+
155169
# Extract consensus indicator
156170
consensus_value <- as.numeric(trimws(result_lines[1]))
157171
consensus <- consensus_value == 1
158-
172+
159173
# Extract consensus proportion
160174
consensus_proportion <- as.numeric(trimws(result_lines[2]))
161175
proportion_found <- TRUE
162176
write_log(sprintf("Found proportion value %f in standard format line 2", consensus_proportion))
163-
177+
164178
# Extract entropy value
165179
entropy <- as.numeric(trimws(result_lines[3]))
166180
entropy_found <- TRUE
167181
write_log(sprintf("Found entropy value %f in standard format line 3", entropy))
168-
182+
169183
# Extract majority prediction result
170184
majority_prediction <- trimws(result_lines[4])
171185
}
172186
}
173-
187+
174188
# Only execute complex parsing logic when not in standard format
175189
if (!standard_format) {
176190
# Try to find the most likely numeric values in the last 4 lines
177191
# Look for lines that start with a number or are just a number
178192
numeric_pattern <- "^\\s*([01]|0\\.[0-9]+|1\\.[0-9]+)\\s*$"
179-
193+
180194
# Check if the first line is a valid consensus indicator (0 or 1)
181195
if (grepl("^\\s*[01]\\s*$", result_lines[1])) {
182196
consensus_value <- as.numeric(trimws(result_lines[1]))
@@ -199,11 +213,11 @@ check_consensus <- function(round_responses, api_keys = NULL, controversy_thresh
199213
consensus <- FALSE
200214
}
201215
}
202-
216+
203217
# Look for a proportion value (between 0 and 1)
204218
if (!exists("proportion_found") || !proportion_found) {
205219
proportion_found <- FALSE
206-
220+
207221
for (i in seq_along(lines)) {
208222
if (grepl("(C|c)onsensus (P|p)roportion", lines[i]) && grepl("=", lines[i])) {
209223
# Extract all content after the equals sign
@@ -226,17 +240,17 @@ check_consensus <- function(round_responses, api_keys = NULL, controversy_thresh
226240
}
227241
}
228242
}
229-
243+
230244
if (!proportion_found) {
231245
write_log("WARNING: Invalid consensus proportion, setting to 0")
232246
consensus_proportion <- 0
233247
}
234248
}
235-
249+
236250
# Look for an entropy value (a non-negative number, often with decimal places)
237251
if (!exists("entropy_found") || !entropy_found) {
238252
entropy_found <- FALSE
239-
253+
240254
for (i in seq_along(lines)) {
241255
if (grepl("(E|e)ntropy", lines[i]) && grepl("=", lines[i])) {
242256
# Extract all content after the equals sign
@@ -259,45 +273,45 @@ check_consensus <- function(round_responses, api_keys = NULL, controversy_thresh
259273
}
260274
}
261275
}
262-
276+
263277
if (!entropy_found) {
264278
write_log("WARNING: Invalid entropy, setting to 0")
265279
entropy <- 0
266280
}
267281
}
268-
282+
269283
# Look for the majority prediction (a non-numeric line)
270284
numeric_pattern <- "^\\s*\\d+(\\.\\d+)?\\s*$"
271285
majority_prediction <- NULL
272-
286+
273287
# First try to extract from the last four lines
274288
for (i in 1:4) {
275-
if (!grepl(numeric_pattern, result_lines[i]) &&
276-
!grepl("0\\.\\d+|1\\.0*|1", result_lines[i]) &&
289+
if (!grepl(numeric_pattern, result_lines[i]) &&
290+
!grepl("0\\.\\d+|1\\.0*|1", result_lines[i]) &&
277291
!grepl("\\d+\\.\\d+|\\d+", result_lines[i])) {
278292
# This line doesn't match any numeric pattern, likely the cell type
279293
majority_prediction <- trimws(result_lines[i])
280294
break
281295
}
282296
}
283297
}
284-
298+
285299
# If we still don't have a majority prediction, search all lines
286300
if (is.null(majority_prediction)) {
287301
for (i in seq_along(lines)) {
288-
if (!grepl(numeric_pattern, lines[i]) &&
289-
!grepl("(C|c)onsensus", lines[i]) &&
290-
!grepl("(E|e)ntropy", lines[i]) &&
291-
!grepl("^\\s*[01]\\s*$", lines[i]) &&
292-
!grepl("^\\s*(0\\.\\d+|1\\.0*|1)\\s*$", lines[i]) &&
293-
!grepl("^\\s*(\\d+\\.\\d+|\\d+)\\s*$", lines[i]) &&
302+
if (!grepl(numeric_pattern, lines[i]) &&
303+
!grepl("(C|c)onsensus", lines[i]) &&
304+
!grepl("(E|e)ntropy", lines[i]) &&
305+
!grepl("^\\s*[01]\\s*$", lines[i]) &&
306+
!grepl("^\\s*(0\\.\\d+|1\\.0*|1)\\s*$", lines[i]) &&
307+
!grepl("^\\s*(\\d+\\.\\d+|\\d+)\\s*$", lines[i]) &&
294308
nchar(trimws(lines[i])) > 0) {
295309
majority_prediction <- trimws(lines[i])
296310
break
297311
}
298312
}
299313
}
300-
314+
301315
if (is.null(majority_prediction)) {
302316
write_log("WARNING: Could not find a valid majority prediction")
303317
majority_prediction <- "Parsing_Failed"
@@ -309,14 +323,14 @@ check_consensus <- function(round_responses, api_keys = NULL, controversy_thresh
309323
entropy <- 0
310324
majority_prediction <- "Unknown"
311325
}
312-
326+
313327
# Return the results
314-
write_log(sprintf("Final results: consensus=%s, proportion=%f, entropy=%f, majority=%s",
315-
ifelse(consensus, "TRUE", "FALSE"),
316-
consensus_proportion,
317-
entropy,
328+
write_log(sprintf("Final results: consensus=%s, proportion=%f, entropy=%f, majority=%s",
329+
ifelse(consensus, "TRUE", "FALSE"),
330+
consensus_proportion,
331+
entropy,
318332
majority_prediction))
319-
333+
320334
return(list(
321335
reached = consensus,
322336
consensus_proportion = consensus_proportion,

0 commit comments

Comments
 (0)