|
| 1 | +categorize_model <- function(id) { |
| 2 | + if (grepl("^gpt-[0-9.]+", id)) { |
| 3 | + match <- regmatches(id, regexpr("^gpt-[0-9.]+", id)) |
| 4 | + return(toupper(match)) # Return as "GPT-4", "GPT-3.5", etc. |
| 5 | + } |
| 6 | + if (grepl("davinci|curie|babbage|ada", id)) return("GPT-3") |
| 7 | + if (grepl("embedding", id)) return("Embedding") |
| 8 | + if (grepl("whisper|speech", id)) return("Audio") |
| 9 | + if (grepl("dall-e|image", id)) return("Image") |
| 10 | + return("Other") |
| 11 | +} |
| 12 | + |
| 13 | +# order list by category, start with models "GPT*" in decreasing order of version then other categories |
| 14 | +order_categories <- function(categories) { |
| 15 | + # Extract unique category names |
| 16 | + unique_cats <- unique(categories) |
| 17 | + |
| 18 | + # Separate GPT-* from others |
| 19 | + gpt_cats <- grep("^GPT-[0-9.]+", unique_cats, value = TRUE) |
| 20 | + other_cats <- setdiff(unique_cats, gpt_cats) |
| 21 | + |
| 22 | + # Sort GPT categories by descending version number |
| 23 | + # Convert "GPT-4" → 4.0, "GPT-3.5" → 3.5 |
| 24 | + gpt_versions <- as.numeric(sub("GPT-", "", gpt_cats)) |
| 25 | + ordered_gpt <- gpt_cats[order(-gpt_versions)] # decreasing |
| 26 | + |
| 27 | + # Final category order |
| 28 | + ordered_categories <- c(ordered_gpt, sort(other_cats)) |
| 29 | + |
| 30 | + ordered_categories |
| 31 | +} |
| 32 | + |
| 33 | +extract_named_model_list <- function(models, categories) { |
| 34 | + if (all(unique(categories) %in% c("Other"))) { |
| 35 | + return(models) |
| 36 | + } |
| 37 | + |
| 38 | + # format into named list |
| 39 | + models_list <- split(models, categories) |
| 40 | + |
| 41 | + # order list by category, start with models "GPT*" in decreasing order of version then other categories |
| 42 | + models_list <- models_list[order_categories(categories)] |
| 43 | + |
| 44 | + return(models_list) |
| 45 | +} |
| 46 | + |
| 47 | +llm_filter_config <- function(api, config) { |
| 48 | + provider <- api$provider # e.g., "OpenAI", "DeepSeek", "Ollama" |
| 49 | + |
| 50 | + supported <- switch( |
| 51 | + provider, |
| 52 | + "OpenAI" = c("model", "messages", "max_tokens", "temperature", "top_p", "n", "stop", "seed", |
| 53 | + "presence_penalty", "frequency_penalty", "logprobs"), |
| 54 | + "DeepSeek" = c("model", "messages", "max_tokens", "temperature", "top_p", "n", "stop", "seed"), |
| 55 | + "Ollama" = c("model", "messages", "max_tokens", "temperature", "top_p", "stop", "seed"), |
| 56 | + character(0) |
| 57 | + ) |
| 58 | + |
| 59 | + all_fields <- names(config) |
| 60 | + unsupported <- setdiff(all_fields, supported) |
| 61 | + |
| 62 | + result <- config[names(config) %in% supported] |
| 63 | + |
| 64 | + if (length(unsupported) > 0) { |
| 65 | + warning_msg <- sprintf("The following inputs are ignored for provider '%s': %s", |
| 66 | + provider, |
| 67 | + paste(unsupported, collapse = ", ")) |
| 68 | + warning(warning_msg, call. = FALSE) |
| 69 | + #result <- append_attr(result, warning_msg, "message") |
| 70 | + } |
| 71 | + |
| 72 | + return(result) |
| 73 | +} |
| 74 | + |
| 75 | +# Append attribute to object |
| 76 | +append_attr <- function(object, val, attr_name) { |
| 77 | + existing <- attr(object, attr_name) |
| 78 | + attr(object, attr_name) <- c(existing, val) |
| 79 | + object |
| 80 | +} |
0 commit comments