Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.jabref.gui.preferences.PreferenceTabViewModel;
import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.logic.ai.AiPreferences;
import org.jabref.logic.ai.models.AiModelService;
import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.preferences.CliPreferences;
Expand Down Expand Up @@ -107,6 +108,7 @@ AiTemplate.CITATION_PARSING_USER_MESSAGE, new SimpleStringProperty()
private final BooleanProperty disableExpertSettings = new SimpleBooleanProperty(true);

private final AiPreferences aiPreferences;
private final AiModelService aiModelService;

private final Validator apiKeyValidator;
private final Validator chatModelValidator;
Expand All @@ -125,6 +127,7 @@ public AiTabViewModel(CliPreferences preferences) {
this.oldLocale = Locale.getDefault();

this.aiPreferences = preferences.getAiPreferences();
this.aiModelService = new AiModelService();

this.enableAi.addListener((_, _, newValue) -> {
disableBasicSettings.set(!newValue);
Expand Down Expand Up @@ -428,6 +431,38 @@ public void resetCurrentTemplate() {
});
}

/**
 * Refreshes the list of chat models offered for the currently selected AI provider.
 * The list is first filled with the hardcoded fallback models so the UI is never empty,
 * then asynchronously replaced with models fetched from the provider's API if that
 * fetch succeeds. On any fetch failure the static list simply remains in place.
 */
public void refreshAvailableModels() {
    AiProvider provider = selectedAiProvider.get();
    if (provider == null) {
        return;
    }

    String apiKey = currentApiKey.get();
    String apiBaseUrl;
    if (customizeExpertSettings.get()) {
        apiBaseUrl = currentApiBaseUrl.get();
    } else {
        apiBaseUrl = provider.getApiUrl();
    }

    // Show the hardcoded models immediately; dynamic results (if any) arrive later.
    chatModelsList.setAll(aiModelService.getStaticModels(provider));

    aiModelService.fetchModelsAsync(provider, apiBaseUrl, apiKey)
                  .thenAccept(fetchedModels -> {
                      if (fetchedModels.isEmpty()) {
                          return;
                      }
                      javafx.application.Platform.runLater(() -> {
                          String previouslySelected = currentChatModel.get();
                          chatModelsList.setAll(fetchedModels);
                          // Restore the user's selection if one existed before the refresh.
                          if (previouslySelected != null && !previouslySelected.isBlank()) {
                              currentChatModel.set(previouslySelected);
                          }
                      });
                  })
                  .exceptionally(_ -> null); // best effort: silently keep the static list
}

@Override
public boolean validateSettings() {
if (enableAi.get()) {
Expand Down
1 change: 1 addition & 0 deletions jablib/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
exports org.jabref.model.groups.event;
exports org.jabref.logic.preview;
exports org.jabref.logic.ai;
exports org.jabref.logic.ai.models;
exports org.jabref.logic.pdf;
exports org.jabref.model.database.event;
exports org.jabref.model.entry.event;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.jabref.logic.ai.models;

import java.util.List;

import org.jabref.model.ai.AiProvider;

/**
 * Strategy interface for retrieving the list of chat models an AI provider offers.
 * <p>
 * Implementations perform the HTTP call against the provider's API. A single
 * implementation may serve several {@link AiProvider} values when those providers
 * expose the same model-listing endpoint (see {@code supports(AiProvider)}).
 */
public interface AiModelProvider {
/**
 * Fetches the list of available models for the given AI provider.
 * Implementations should return an empty list (never {@code null}) when the
 * fetch fails or yields no models, so callers can fall back to hardcoded defaults.
 *
 * @param aiProvider The AI provider to fetch models from
 * @param apiBaseUrl The base URL for the API
 * @param apiKey The API key for authentication
 * @return A list of available model names; empty if none could be fetched
 */
List<String> fetchModels(AiProvider aiProvider, String apiBaseUrl, String apiKey);

/**
 * Checks if this provider supports the given AI provider type.
 *
 * @param aiProvider The AI provider to check
 * @return true if this provider can fetch models for the given AI provider
 */
boolean supports(AiProvider aiProvider);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package org.jabref.logic.ai.models;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.model.ai.AiProvider;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Service for managing AI models from different providers.
* Provides both static (hardcoded) and dynamic (API-fetched) model lists.
*/
public class AiModelService {
private static final Logger LOGGER = LoggerFactory.getLogger(AiModelService.class);
private static final int FETCH_TIMEOUT_SECONDS = 5;

private final List<AiModelProvider> modelProviders;

public AiModelService() {
this.modelProviders = new ArrayList<>();
this.modelProviders.add(new OpenAiCompatibleModelProvider());
}

/**
* Gets the list of available models for the given provider.
* First attempts to fetch models dynamically from the API.
* If that fails or times out, falls back to the hardcoded list.
*
* @param aiProvider The AI provider
* @param apiBaseUrl The base URL for the API
* @param apiKey The API key for authentication
* @return A list of available model names
*/
public List<String> getAvailableModels(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
List<String> dynamicModels = fetchModelsDynamically(aiProvider, apiBaseUrl, apiKey);

if (!dynamicModels.isEmpty()) {
LOGGER.info("Using {} dynamic models for {}", dynamicModels.size(), aiProvider.getLabel());
return dynamicModels;
}

List<String> staticModels = AiDefaultPreferences.getAvailableModels(aiProvider);
LOGGER.debug("Using {} hardcoded models for {}", staticModels.size(), aiProvider.getLabel());
return staticModels;
}

/**
* Gets the list of available models for the given provider, using only hardcoded values.
*
* @param aiProvider The AI provider
* @return A list of available model names
*/
public List<String> getStaticModels(AiProvider aiProvider) {
return AiDefaultPreferences.getAvailableModels(aiProvider);
}

/**
* Asynchronously fetches the list of available models from the API.
*
* @param aiProvider The AI provider
* @param apiBaseUrl The base URL for the API
* @param apiKey The API key for authentication
* @return A CompletableFuture containing the list of model names
*/
public CompletableFuture<List<String>> fetchModelsAsync(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is too much asynchronous/parallelism.

What we use in JabRef, and what I think is enough to do is to actually write only synchronous methods, and if we need something to happen in background we create a new class that inherits from BackgroundTask and it calls a synchronous method.

So, you write everything synchronously, then you create a class FetchAiModelsBackgroundTask, and in the AiService you start that task. Could you implement an approach like this?

return CompletableFuture.supplyAsync(() -> fetchModelsDynamically(aiProvider, apiBaseUrl, apiKey));
}

private List<String> fetchModelsDynamically(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
for (AiModelProvider provider : modelProviders) {
if (provider.supports(aiProvider)) {
try {
CompletableFuture<List<String>> future = CompletableFuture.supplyAsync(
() -> provider.fetchModels(aiProvider, apiBaseUrl, apiKey)
);

List<String> models = future.get(FETCH_TIMEOUT_SECONDS, TimeUnit.SECONDS);

if (models != null && !models.isEmpty()) {
return models;
}
} catch (Exception e) {
LOGGER.debug("Failed to fetch models for {}: {}", aiProvider.getLabel(), e.getMessage());
}
}
}

return List.of();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package org.jabref.logic.ai.models;

import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

import org.jabref.model.ai.AiProvider;

import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Model provider for OpenAI-compatible APIs.
* Fetches available models from the /v1/models endpoint.
*/
public class OpenAiCompatibleModelProvider implements AiModelProvider {
private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiCompatibleModelProvider.class);
private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(10);

private final HttpClient httpClient;

public OpenAiCompatibleModelProvider() {
this.httpClient = HttpClient.newBuilder()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just in case, can you check: does this object needs to be closed?

.connectTimeout(Duration.ofSeconds(5))
.build();
}

public OpenAiCompatibleModelProvider(HttpClient httpClient) {
this.httpClient = httpClient;
}

@Override
public List<String> fetchModels(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
List<String> models = new ArrayList<>();

if (apiKey == null || apiKey.isBlank()) {
LOGGER.debug("API key is not provided for {}, skipping model fetch", aiProvider.getLabel());
return models;
}

try {
String modelsEndpoint = buildModelsEndpoint(apiBaseUrl);
HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create(modelsEndpoint))
.header("Authorization", "Bearer " + apiKey)
.header("Content-Type", "application/json")
.timeout(REQUEST_TIMEOUT)
.GET()
.build();

HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());

if (response.statusCode() == 200) {
models = parseModelsFromResponse(response.body());
LOGGER.info("Successfully fetched {} models from {}", models.size(), aiProvider.getLabel());
} else {
LOGGER.debug("Failed to fetch models from {} (status: {})", aiProvider.getLabel(), response.statusCode());
}
} catch (IOException | InterruptedException e) {
LOGGER.debug("Failed to fetch models from {}: {}", aiProvider.getLabel(), e.getMessage());
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
} catch (Exception e) {
LOGGER.debug("Unexpected error while fetching models from {}: {}", aiProvider.getLabel(), e.getMessage());
}

return models;
}

@Override
public boolean supports(AiProvider aiProvider) {
// OpenAI-compatible providers: OpenAI, Mistral AI, and custom OpenAI-compatible endpoints
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment (and potentially other docstrings) should be rewritten to sound like this:

"Mistral provides an API that has the same endpoint for fetching available models as OpenAI, thus this class is able to fetch models also from Mistral"

This is not ideal wording, but I hope you got the idea - so that person reading code would understand why these providers are called "OpenAI compatible". It's actually my mistake too, as I also haven't explained this properly

return aiProvider == AiProvider.OPEN_AI
|| aiProvider == AiProvider.MISTRAL_AI
|| aiProvider == AiProvider.GPT4ALL;
}

private String buildModelsEndpoint(String apiBaseUrl) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use a type that is more specific to URLs, rather than a raw string?

String baseUrl = apiBaseUrl.trim();
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
}

if (baseUrl.endsWith("/v1")) {
return baseUrl + "/models";
} else {
return baseUrl + "/v1/models";
}
}

private List<String> parseModelsFromResponse(String responseBody) {
List<String> models = new ArrayList<>();

try {
JsonObject jsonResponse = JsonParser.parseString(responseBody).getAsJsonObject();

if (jsonResponse.has("data") && jsonResponse.get("data").isJsonArray()) {
JsonArray modelsArray = jsonResponse.getAsJsonArray("data");

for (JsonElement element : modelsArray) {
if (element.isJsonObject()) {
JsonObject modelObject = element.getAsJsonObject();
if (modelObject.has("id")) {
String modelId = modelObject.get("id").getAsString();
models.add(modelId);
}
}
}
}
} catch (Exception e) {
LOGGER.warn("Failed to parse models response: {}", e.getMessage());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please read our docs on Logging.

}

return models;
}
}
Loading