diff --git a/jabgui/src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java b/jabgui/src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
index c6308a77f29..18b8dc6d65f 100644
--- a/jabgui/src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
+++ b/jabgui/src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
@@ -22,6 +22,7 @@
 import org.jabref.gui.preferences.PreferenceTabViewModel;
 import org.jabref.logic.ai.AiDefaultPreferences;
 import org.jabref.logic.ai.AiPreferences;
+import org.jabref.logic.ai.models.AiModelService;
 import org.jabref.logic.ai.templates.AiTemplate;
 import org.jabref.logic.l10n.Localization;
 import org.jabref.logic.preferences.CliPreferences;
@@ -107,6 +108,7 @@ AiTemplate.CITATION_PARSING_USER_MESSAGE, new SimpleStringProperty()
     private final BooleanProperty disableExpertSettings = new SimpleBooleanProperty(true);
 
     private final AiPreferences aiPreferences;
+    private final AiModelService aiModelService;
 
     private final Validator apiKeyValidator;
     private final Validator chatModelValidator;
@@ -125,6 +127,7 @@ public AiTabViewModel(CliPreferences preferences) {
         this.oldLocale = Locale.getDefault();
 
         this.aiPreferences = preferences.getAiPreferences();
+        this.aiModelService = new AiModelService();
 
         this.enableAi.addListener((_, _, newValue) -> {
             disableBasicSettings.set(!newValue);
@@ -428,6 +431,38 @@ public void resetCurrentTemplate() {
         });
     }
 
+    /**
+     * Fetches available models for the currently selected AI provider.
+     * Attempts to fetch models dynamically from the API, falling back to the hardcoded models if the fetch fails.
+     * This method runs asynchronously and updates the chatModelsList when complete.
+     */
+    public void refreshAvailableModels() {
+        AiProvider provider = selectedAiProvider.get();
+        if (provider == null) {
+            return;
+        }
+
+        String apiKey = currentApiKey.get();
+        String apiBaseUrl = customizeExpertSettings.get() ? currentApiBaseUrl.get() : provider.getApiUrl();
+
+        List<String> staticModels = aiModelService.getStaticModels(provider);
+        chatModelsList.setAll(staticModels);
+
+        aiModelService.fetchModelsAsync(provider, apiBaseUrl, apiKey)
+                      .thenAccept(dynamicModels -> {
+                          if (!dynamicModels.isEmpty()) {
+                              javafx.application.Platform.runLater(() -> {
+                                  String currentModel = currentChatModel.get();
+                                  chatModelsList.setAll(dynamicModels);
+                                  if (currentModel != null && !currentModel.isBlank()) {
+                                      currentChatModel.set(currentModel);
+                                  }
+                              });
+                          }
+                      })
+                      .exceptionally(_ -> null);
+    }
+
     @Override
     public boolean validateSettings() {
         if (enableAi.get()) {
diff --git a/jablib/src/main/java/module-info.java b/jablib/src/main/java/module-info.java
index dc137e314f5..3fae712cdd3 100644
--- a/jablib/src/main/java/module-info.java
+++ b/jablib/src/main/java/module-info.java
@@ -39,6 +39,7 @@
     exports org.jabref.model.groups.event;
     exports org.jabref.logic.preview;
    exports org.jabref.logic.ai;
+    exports org.jabref.logic.ai.models;
     exports org.jabref.logic.pdf;
     exports org.jabref.model.database.event;
     exports org.jabref.model.entry.event;
diff --git a/jablib/src/main/java/org/jabref/logic/ai/models/AiModelProvider.java b/jablib/src/main/java/org/jabref/logic/ai/models/AiModelProvider.java
new file mode 100644
index 00000000000..63782bd9ebc
--- /dev/null
+++ b/jablib/src/main/java/org/jabref/logic/ai/models/AiModelProvider.java
@@ -0,0 +1,29 @@
+package org.jabref.logic.ai.models;
+
+import java.util.List;
+
+import org.jabref.model.ai.AiProvider;
+
+/**
+ * Interface for fetching available AI models from different providers.
+ * Implementations should handle API calls to retrieve model lists dynamically.
+ */
+public interface AiModelProvider {
+    /**
+     * Fetches the list of available models for the given AI provider.
+     *
+     * @param aiProvider The AI provider to fetch models from
+     * @param apiBaseUrl The base URL for the API
+     * @param apiKey The API key for authentication
+     * @return A list of available model names
+     */
+    List<String> fetchModels(AiProvider aiProvider, String apiBaseUrl, String apiKey);
+
+    /**
+     * Checks if this provider supports the given AI provider type.
+     *
+     * @param aiProvider The AI provider to check
+     * @return true if this provider can fetch models for the given AI provider
+     */
+    boolean supports(AiProvider aiProvider);
+}
\ No newline at end of file
diff --git a/jablib/src/main/java/org/jabref/logic/ai/models/AiModelService.java b/jablib/src/main/java/org/jabref/logic/ai/models/AiModelService.java
new file mode 100644
index 00000000000..fbb40e48b9d
--- /dev/null
+++ b/jablib/src/main/java/org/jabref/logic/ai/models/AiModelService.java
@@ -0,0 +1,95 @@
+package org.jabref.logic.ai.models;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+
+import org.jabref.logic.ai.AiDefaultPreferences;
+import org.jabref.model.ai.AiProvider;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Service for managing AI models from different providers.
+ * Provides both static (hardcoded) and dynamic (API-fetched) model lists.
+ */
+public class AiModelService {
+    private static final Logger LOGGER = LoggerFactory.getLogger(AiModelService.class);
+    private static final int FETCH_TIMEOUT_SECONDS = 5;
+
+    private final List<AiModelProvider> modelProviders;
+
+    public AiModelService() {
+        this.modelProviders = new ArrayList<>();
+        this.modelProviders.add(new OpenAiCompatibleModelProvider());
+    }
+
+    /**
+     * Gets the list of available models for the given provider.
+     * First attempts to fetch models dynamically from the API.
+     * If that fails or times out, falls back to the hardcoded list.
+     *
+     * @param aiProvider The AI provider
+     * @param apiBaseUrl The base URL for the API
+     * @param apiKey The API key for authentication
+     * @return A list of available model names
+     */
+    public List<String> getAvailableModels(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
+        List<String> dynamicModels = fetchModelsDynamically(aiProvider, apiBaseUrl, apiKey);
+
+        if (!dynamicModels.isEmpty()) {
+            LOGGER.info("Using {} dynamic models for {}", dynamicModels.size(), aiProvider.getLabel());
+            return dynamicModels;
+        }
+
+        List<String> staticModels = AiDefaultPreferences.getAvailableModels(aiProvider);
+        LOGGER.debug("Using {} hardcoded models for {}", staticModels.size(), aiProvider.getLabel());
+        return staticModels;
+    }
+
+    /**
+     * Gets the list of available models for the given provider, using only hardcoded values.
+     *
+     * @param aiProvider The AI provider
+     * @return A list of available model names
+     */
+    public List<String> getStaticModels(AiProvider aiProvider) {
+        return AiDefaultPreferences.getAvailableModels(aiProvider);
+    }
+
+    /**
+     * Asynchronously fetches the list of available models from the API.
+     *
+     * @param aiProvider The AI provider
+     * @param apiBaseUrl The base URL for the API
+     * @param apiKey The API key for authentication
+     * @return A CompletableFuture containing the list of model names
+     */
+    public CompletableFuture<List<String>> fetchModelsAsync(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
+        return CompletableFuture.supplyAsync(() -> fetchModelsDynamically(aiProvider, apiBaseUrl, apiKey));
+    }
+
+    private List<String> fetchModelsDynamically(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
+        for (AiModelProvider provider : modelProviders) {
+            if (provider.supports(aiProvider)) {
+                try {
+                    CompletableFuture<List<String>> future = CompletableFuture.supplyAsync(
+                            () -> provider.fetchModels(aiProvider, apiBaseUrl, apiKey)
+                    );
+
+                    List<String> models = future.get(FETCH_TIMEOUT_SECONDS, TimeUnit.SECONDS);
+
+                    if (models != null && !models.isEmpty()) {
+                        return models;
+                    }
+                } catch (Exception e) {
+                    LOGGER.debug("Failed to fetch models for {}: {}", aiProvider.getLabel(), e.getMessage());
+                }
+            }
+        }
+
+        return List.of();
+    }
+}
diff --git a/jablib/src/main/java/org/jabref/logic/ai/models/OpenAiCompatibleModelProvider.java b/jablib/src/main/java/org/jabref/logic/ai/models/OpenAiCompatibleModelProvider.java
new file mode 100644
index 00000000000..ee48dadcb93
--- /dev/null
+++ b/jablib/src/main/java/org/jabref/logic/ai/models/OpenAiCompatibleModelProvider.java
@@ -0,0 +1,126 @@
+package org.jabref.logic.ai.models;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.jabref.model.ai.AiProvider;
+
+import com.google.gson.JsonArray;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Model provider for OpenAI-compatible APIs.
+ * Fetches available models from the /v1/models endpoint.
+ */
+public class OpenAiCompatibleModelProvider implements AiModelProvider {
+    private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiCompatibleModelProvider.class);
+    private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(10);
+
+    private final HttpClient httpClient;
+
+    public OpenAiCompatibleModelProvider() {
+        this.httpClient = HttpClient.newBuilder()
+                                    .connectTimeout(Duration.ofSeconds(5))
+                                    .build();
+    }
+
+    public OpenAiCompatibleModelProvider(HttpClient httpClient) {
+        this.httpClient = httpClient;
+    }
+
+    @Override
+    public List<String> fetchModels(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
+        List<String> models = new ArrayList<>();
+
+        if (apiKey == null || apiKey.isBlank()) {
+            LOGGER.debug("API key is not provided for {}, skipping model fetch", aiProvider.getLabel());
+            return models;
+        }
+
+        try {
+            String modelsEndpoint = buildModelsEndpoint(apiBaseUrl);
+            HttpRequest request = HttpRequest.newBuilder()
+                                             .uri(URI.create(modelsEndpoint))
+                                             .header("Authorization", "Bearer " + apiKey)
+                                             .header("Content-Type", "application/json")
+                                             .timeout(REQUEST_TIMEOUT)
+                                             .GET()
+                                             .build();
+
+            HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
+
+            if (response.statusCode() == 200) {
+                models = parseModelsFromResponse(response.body());
+                LOGGER.info("Successfully fetched {} models from {}", models.size(), aiProvider.getLabel());
+            } else {
+                LOGGER.debug("Failed to fetch models from {} (status: {})", aiProvider.getLabel(), response.statusCode());
+            }
+        } catch (IOException | InterruptedException e) {
+            LOGGER.debug("Failed to fetch models from {}: {}", aiProvider.getLabel(), e.getMessage());
+            if (e instanceof InterruptedException) {
+                Thread.currentThread().interrupt();
+            }
+        } catch (Exception e) {
+            LOGGER.debug("Unexpected error while fetching models from {}: {}", aiProvider.getLabel(), e.getMessage());
+        }
+
+        return models;
+    }
+
+    @Override
+    public boolean supports(AiProvider aiProvider) {
+        // OpenAI-compatible providers: OpenAI, Mistral AI, and custom OpenAI-compatible endpoints
+        return aiProvider == AiProvider.OPEN_AI
+                || aiProvider == AiProvider.MISTRAL_AI
+                || aiProvider == AiProvider.GPT4ALL;
+    }
+
+    private String buildModelsEndpoint(String apiBaseUrl) {
+        String baseUrl = apiBaseUrl.trim();
+        if (baseUrl.endsWith("/")) {
+            baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
+        }
+
+        if (baseUrl.endsWith("/v1")) {
+            return baseUrl + "/models";
+        } else {
+            return baseUrl + "/v1/models";
+        }
+    }
+
+    private List<String> parseModelsFromResponse(String responseBody) {
+        List<String> models = new ArrayList<>();
+
+        try {
+            JsonObject jsonResponse = JsonParser.parseString(responseBody).getAsJsonObject();
+
+            if (jsonResponse.has("data") && jsonResponse.get("data").isJsonArray()) {
+                JsonArray modelsArray = jsonResponse.getAsJsonArray("data");
+
+                for (JsonElement element : modelsArray) {
+                    if (element.isJsonObject()) {
+                        JsonObject modelObject = element.getAsJsonObject();
+                        if (modelObject.has("id")) {
+                            String modelId = modelObject.get("id").getAsString();
+                            models.add(modelId);
+                        }
+                    }
+                }
+            }
+        } catch (Exception e) {
+            LOGGER.warn("Failed to parse models response: {}", e.getMessage());
+        }
+
+        return models;
+    }
+}
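
Usage sketch (not part of the patch): the snippet below illustrates how the AiModelService API added above is intended to be called, roughly mirroring what AiTabViewModel.refreshAvailableModels() does. The ModelFetchDemo class name and the OPENAI_API_KEY environment variable are illustrative assumptions for this sketch; everything else uses only the types and methods introduced in this diff.

import java.util.List;

import org.jabref.logic.ai.models.AiModelService;
import org.jabref.model.ai.AiProvider;

public class ModelFetchDemo {
    public static void main(String[] args) {
        AiModelService service = new AiModelService();

        // Synchronous variant: queries the provider's /v1/models endpoint (5 s timeout)
        // and falls back to the hardcoded AiDefaultPreferences list if the fetch fails.
        // A blank or null API key skips the remote fetch entirely.
        List<String> models = service.getAvailableModels(
                AiProvider.OPEN_AI,
                "https://api.openai.com/v1",
                System.getenv("OPENAI_API_KEY"));
        models.forEach(System.out::println);

        // Asynchronous variant, as used by the preferences view model: the UI first shows
        // the static list, then replaces it once the dynamic list arrives.
        service.fetchModelsAsync(AiProvider.OPEN_AI, "https://api.openai.com/v1",
                        System.getenv("OPENAI_API_KEY"))
               .thenAccept(dynamic -> System.out.println("Fetched " + dynamic.size() + " models"))
               .exceptionally(e -> null);
    }
}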