diff --git a/genai/snippets/src/main/java/genai/tuning/TuningJobCreate.java b/genai/snippets/src/main/java/genai/tuning/TuningJobCreate.java new file mode 100644 index 00000000000..c4764910814 --- /dev/null +++ b/genai/snippets/src/main/java/genai/tuning/TuningJobCreate.java @@ -0,0 +1,121 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package genai.tuning; + +// [START googlegenaisdk_tuning_job_create] + +import static com.google.genai.types.JobState.Known.JOB_STATE_PENDING; +import static com.google.genai.types.JobState.Known.JOB_STATE_RUNNING; + +import com.google.genai.Client; +import com.google.genai.types.CreateTuningJobConfig; +import com.google.genai.types.GetTuningJobConfig; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.JobState; +import com.google.genai.types.TunedModel; +import com.google.genai.types.TunedModelCheckpoint; +import com.google.genai.types.TuningDataset; +import com.google.genai.types.TuningJob; +import com.google.genai.types.TuningValidationDataset; +import java.util.Collections; +import java.util.EnumSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +public class TuningJobCreate { + + public static void main(String[] args) throws InterruptedException { + // TODO(developer): Replace these variables before running the sample. + String model = "gemini-2.5-flash"; + createTuningJob(model); + } + + // Shows how to create a supervised fine-tuning job using training and validation datasets + public static String createTuningJob(String model) throws InterruptedException { + // Client Initialization. Once created, it can be reused for multiple requests. + try (Client client = + Client.builder() + .location("us-central1") + .vertexAI(true) + .httpOptions(HttpOptions.builder().apiVersion("v1beta1").build()) + .build()) { + + String trainingDatasetUri = + "gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_train_data.jsonl"; + TuningDataset trainingDataset = TuningDataset.builder().gcsUri(trainingDatasetUri).build(); + + String validationDatasetUri = + "gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_validation_data.jsonl"; + TuningValidationDataset validationDataset = + TuningValidationDataset.builder().gcsUri(validationDatasetUri).build(); + + TuningJob tuningJob = + client.tunings.tune( + model, + trainingDataset, + CreateTuningJobConfig.builder() + .tunedModelDisplayName("your-display-name") + .validationDataset(validationDataset) + .build()); + + String jobName = + tuningJob.name().orElseThrow(() -> new IllegalStateException("Missing job name")); + Optional jobState = tuningJob.state(); + Set runningStates = EnumSet.of(JOB_STATE_PENDING, JOB_STATE_RUNNING); + + while (jobState.isPresent() && runningStates.contains(jobState.get().knownEnum())) { + System.out.println("Job state: " + jobState.get()); + tuningJob = client.tunings.get(jobName, GetTuningJobConfig.builder().build()); + jobState = tuningJob.state(); + TimeUnit.SECONDS.sleep(60); + } + + tuningJob.tunedModel().flatMap(TunedModel::model).ifPresent(System.out::println); + tuningJob.tunedModel().flatMap(TunedModel::endpoint).ifPresent(System.out::println); + tuningJob.experiment().ifPresent(System.out::println); + // Example response: + // projects/123456789012/locations/us-central1/models/6129850992130260992@1 + // projects/123456789012/locations/us-central1/endpoints/105055037499113472 + // projects/123456789012/locations/us-central1/metadataStores/default/contexts/experiment_id + + List checkpoints = + tuningJob.tunedModel().flatMap(TunedModel::checkpoints).orElse(Collections.emptyList()); + + int index = 0; + for (TunedModelCheckpoint checkpoint : checkpoints) { + System.out.println("Checkpoint " + (++index)); + checkpoint + .checkpointId() + .ifPresent(checkpointId -> System.out.println("checkpointId=" + checkpointId)); + checkpoint.epoch().ifPresent(epoch -> System.out.println("epoch=" + epoch)); + checkpoint.step().ifPresent(step -> System.out.println("step=" + step)); + checkpoint.endpoint().ifPresent(endpoint -> System.out.println("endpoint=" + endpoint)); + } + // Example response: + // Checkpoint 1 + // checkpointId=1 + // epoch=2 + // step=34 + // endpoint=projects/project/locations/location/endpoints/105055037499113472 + // ... + return jobName; + } + } +} +// [END googlegenaisdk_tuning_job_create] diff --git a/genai/snippets/src/main/java/genai/tuning/TuningJobGet.java b/genai/snippets/src/main/java/genai/tuning/TuningJobGet.java new file mode 100644 index 00000000000..127ac8eec30 --- /dev/null +++ b/genai/snippets/src/main/java/genai/tuning/TuningJobGet.java @@ -0,0 +1,61 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package genai.tuning; + +// [START googlegenaisdk_tuning_job_get] + +import com.google.genai.Client; +import com.google.genai.types.GetTuningJobConfig; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.TunedModel; +import com.google.genai.types.TuningJob; +import java.util.Optional; + +public class TuningJobGet { + + public static void main(String[] args) { + // TODO(developer): Replace these variables before running the sample. + // E.g. tuningJobName = + // "projects/123456789012/locations/us-central1/tuningJobs/123456789012345" + String tuningJobName = "your-job-name"; + getTuningJob(tuningJobName); + } + + // Shows how to get a tuning job + public static Optional getTuningJob(String tuningJobName) { + // Client Initialization. Once created, it can be reused for multiple requests. + try (Client client = + Client.builder() + .location("us-central1") + .vertexAI(true) + .httpOptions(HttpOptions.builder().apiVersion("v1").build()) + .build()) { + + TuningJob tuningJob = client.tunings.get(tuningJobName, GetTuningJobConfig.builder().build()); + + tuningJob.tunedModel().flatMap(TunedModel::model).ifPresent(System.out::println); + tuningJob.tunedModel().flatMap(TunedModel::endpoint).ifPresent(System.out::println); + tuningJob.experiment().ifPresent(System.out::println); + // Example response: + // projects/123456789012/locations/us-central1/models/6129850992130260992@1 + // projects/123456789012/locations/us-central1/endpoints/105055037499113472 + // projects/123456789012/locations/us-central1/metadataStores/default/contexts/experiment_id + return tuningJob.name(); + } + } +} +// [END googlegenaisdk_tuning_job_get] diff --git a/genai/snippets/src/main/java/genai/tuning/TuningJobList.java b/genai/snippets/src/main/java/genai/tuning/TuningJobList.java new file mode 100644 index 00000000000..25b4263cf1d --- /dev/null +++ b/genai/snippets/src/main/java/genai/tuning/TuningJobList.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package genai.tuning; + +// [START googlegenaisdk_tuning_job_list] + +import com.google.genai.Client; +import com.google.genai.Pager; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.ListTuningJobsConfig; +import com.google.genai.types.TuningJob; + +public class TuningJobList { + + public static void main(String[] args) { + listTuningJob(); + } + + // Shows how to list the available tuning jobs + public static Pager listTuningJob() { + // Client Initialization. Once created, it can be reused for multiple requests. + try (Client client = + Client.builder() + .location("us-central1") + .vertexAI(true) + .httpOptions(HttpOptions.builder().apiVersion("v1").build()) + .build()) { + + Pager tuningJobs = client.tunings.list(ListTuningJobsConfig.builder().build()); + for (TuningJob job : tuningJobs) { + job.name().ifPresent(System.out::println); + // Example response: + // projects/123456789012/locations/us-central1/tuningJobs/329583781566480384 + } + + return tuningJobs; + } + } +} +// [END googlegenaisdk_tuning_job_list] diff --git a/genai/snippets/src/main/java/genai/tuning/TuningTextGenWithTxt.java b/genai/snippets/src/main/java/genai/tuning/TuningTextGenWithTxt.java new file mode 100644 index 00000000000..6a631ff64f0 --- /dev/null +++ b/genai/snippets/src/main/java/genai/tuning/TuningTextGenWithTxt.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package genai.tuning; + +// [START googlegenaisdk_tuning_textgen_with_txt] + +import com.google.genai.Client; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.GetTuningJobConfig; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.TunedModel; +import com.google.genai.types.TuningJob; + +public class TuningTextGenWithTxt { + + public static void main(String[] args) { + // TODO(developer): Replace these variables before running the sample. + // E.g. tuningJobName = + // "projects/123456789012/locations/us-central1/tuningJobs/123456789012345" + String tuningJobName = "your-job-name"; + predictWithTunedEndpoint(tuningJobName); + } + + // Shows how to predict with a tuned model endpoint + public static String predictWithTunedEndpoint(String tuningJobName) { + // Client Initialization. Once created, it can be reused for multiple requests. + try (Client client = + Client.builder() + .location("us-central1") + .vertexAI(true) + .httpOptions(HttpOptions.builder().apiVersion("v1").build()) + .build()) { + + TuningJob tuningJob = client.tunings.get(tuningJobName, GetTuningJobConfig.builder().build()); + + String endpoint = + tuningJob + .tunedModel() + .flatMap(TunedModel::endpoint) + .orElseThrow(() -> new IllegalStateException("Missing tuned model endpoint")); + + GenerateContentResponse response = + client.models.generateContent( + endpoint, "Why is the sky blue?", GenerateContentConfig.builder().build()); + + System.out.println(response.text()); + // Example response: + // The sky is blue because of a phenomenon called Rayleigh scattering... + return response.text(); + } + } +} +// [END googlegenaisdk_tuning_textgen_with_txt] diff --git a/genai/snippets/src/test/java/genai/tuning/TuningIT.java b/genai/snippets/src/test/java/genai/tuning/TuningIT.java new file mode 100644 index 00000000000..795b7f370c8 --- /dev/null +++ b/genai/snippets/src/test/java/genai/tuning/TuningIT.java @@ -0,0 +1,191 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package genai.tuning; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.RETURNS_SELF; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.genai.Client; +import com.google.genai.Models; +import com.google.genai.Pager; +import com.google.genai.Tunings; +import com.google.genai.types.CreateTuningJobConfig; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.GetTuningJobConfig; +import com.google.genai.types.JobState; +import com.google.genai.types.ListTuningJobsConfig; +import com.google.genai.types.TunedModel; +import com.google.genai.types.TuningDataset; +import com.google.genai.types.TuningJob; +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.lang.reflect.Field; +import java.util.Iterator; +import java.util.Optional; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.MockedStatic; + +@RunWith(JUnit4.class) +public class TuningIT { + + private static final String GEMINI_FLASH = "gemini-2.5-flash"; + private ByteArrayOutputStream bout; + private PrintStream out; + private Client.Builder mockedBuilder; + private Client mockedClient; + private Tunings mockedTunings; + private TuningJob mockedResponse; + private MockedStatic mockedStatic; + + // Check if the required environment variables are set. + public static void requireEnvVar(String envVarName) { + assertWithMessage(String.format("Missing environment variable '%s' ", envVarName)) + .that(System.getenv(envVarName)) + .isNotEmpty(); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_CLOUD_PROJECT"); + } + + @Before + public void setUp() throws NoSuchFieldException, IllegalAccessException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + mockedBuilder = mock(Client.Builder.class, RETURNS_SELF); + mockedClient = mock(Client.class); + mockedTunings = mock(Tunings.class); + mockedResponse = mock(TuningJob.class); + mockedStatic = mockStatic(Client.class); + mockedStatic.when(Client::builder).thenReturn(mockedBuilder); + when(mockedBuilder.build()).thenReturn(mockedClient); + // Using reflection because 'tunings' is a final field and cannot be mockable directly + Field field = Client.class.getDeclaredField("tunings"); + field.setAccessible(true); + field.set(mockedClient, mockedTunings); + } + + @After + public void tearDown() { + System.setOut(null); + bout.reset(); + mockedStatic.close(); + } + + @Test + public void testTuningJobCreate() throws InterruptedException { + + String expectedResponse = "test-tuning-job"; + + when(mockedClient.tunings.tune( + anyString(), any(TuningDataset.class), any(CreateTuningJobConfig.class))) + .thenReturn(mockedResponse); + + TunedModel tunedModel = + TunedModel.builder().model("test-model").endpoint("test-endpoint").build(); + when(mockedResponse.name()).thenReturn(Optional.of("test-tuning-job")); + when(mockedResponse.experiment()).thenReturn(Optional.of("test-experiment")); + when(mockedResponse.tunedModel()).thenReturn(Optional.of(tunedModel)); + when(mockedResponse.state()) + .thenReturn(Optional.of(new JobState(JobState.Known.JOB_STATE_SUCCEEDED))); + + String response = TuningJobCreate.createTuningJob(GEMINI_FLASH); + + verify(mockedClient.tunings, times(1)) + .tune(anyString(), any(TuningDataset.class), any(CreateTuningJobConfig.class)); + assertThat(response).isNotEmpty(); + assertThat(response).isEqualTo(expectedResponse); + } + + @Test + public void testTuningJobGet() { + when(mockedClient.tunings.get(anyString(), any(GetTuningJobConfig.class))) + .thenReturn(mockedResponse); + when(mockedResponse.name()).thenReturn(Optional.of("test-tuning-job")); + + Optional response = TuningJobGet.getTuningJob(GEMINI_FLASH); + verify(mockedClient.tunings, times(1)).get(anyString(), any(GetTuningJobConfig.class)); + assertThat(response).isPresent(); + assertThat(response.get()).isEqualTo("test-tuning-job"); + } + + @Test + public void testTuningJobList() { + Pager mockPagerResponse = mock(Pager.class); + Iterator mockIterator = mock(Iterator.class); + + TuningJob tuningJob1 = TuningJob.builder().name("test-tuning-job1").build(); + TuningJob tuningJob2 = TuningJob.builder().name("test-tuning-job2").build(); + + when(mockedClient.tunings.list(any(ListTuningJobsConfig.class))).thenReturn(mockPagerResponse); + when(mockPagerResponse.size()).thenReturn(2); + when(mockPagerResponse.iterator()).thenReturn(mockIterator); + when(mockIterator.hasNext()).thenReturn(true, true, false); + when(mockIterator.next()).thenReturn(tuningJob1, tuningJob2); + + Pager tuningJobs = TuningJobList.listTuningJob(); + verify(mockedClient.tunings, times(1)).list(any(ListTuningJobsConfig.class)); + assertThat(tuningJobs.size()).isEqualTo(2); + assertThat(bout.toString()).isNotEmpty(); + assertThat(bout.toString()).contains("test-tuning-job1"); + assertThat(bout.toString()).contains("test-tuning-job2"); + } + + @Test + public void testTuningTextGenWithTxt() throws NoSuchFieldException, IllegalAccessException { + Models mockedModels = mock(Models.class); + // Using reflection because 'models' is a final field and cannot be mockable directly + Field field = Client.class.getDeclaredField("models"); + field.setAccessible(true); + field.set(mockedClient, mockedModels); + + when(mockedClient.tunings.get(anyString(), any(GetTuningJobConfig.class))) + .thenReturn(mockedResponse); + TunedModel tunedModel = TunedModel.builder().endpoint("test-endpoint").build(); + when(mockedResponse.tunedModel()).thenReturn(Optional.of(tunedModel)); + + GenerateContentResponse mockedGeneratedResponse = mock(GenerateContentResponse.class); + + when(mockedClient.models.generateContent( + anyString(), anyString(), any(GenerateContentConfig.class))) + .thenReturn(mockedGeneratedResponse); + when(mockedGeneratedResponse.text()).thenReturn("Example response"); + + String response = TuningTextGenWithTxt.predictWithTunedEndpoint("test-tuning-job"); + + verify(mockedClient.tunings, times(1)).get(anyString(), any(GetTuningJobConfig.class)); + verify(mockedClient.models, times(1)) + .generateContent(anyString(), anyString(), any(GenerateContentConfig.class)); + assertThat(response).isNotEmpty(); + } +}