|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 |
|
16 | | -def test_checkpoint(name: str) -> str: |
| 16 | +def predict_with_checkpoints(tuning_job_name: str) -> str: |
17 | 17 | # [START googlegenaisdk_tuning_with_checkpoints_test] |
18 | 18 | from google import genai |
19 | 19 | from google.genai.types import HttpOptions |
20 | 20 |
|
21 | 21 | client = genai.Client(http_options=HttpOptions(api_version="v1")) |
22 | 22 |
|
23 | 23 | # Get the tuning job and the tuned model. |
24 | | - # Eg. name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345" |
25 | | - tuning_job = client.tunings.get(name=name) |
| 24 | + # Eg. tuning_job_name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345" |
| 25 | + tuning_job = client.tunings.get(name=tuning_job_name) |
26 | 26 |
|
27 | 27 | contents = "Why is the sky blue?" |
28 | 28 |
|
29 | | - # Tests the default checkpoint |
| 29 | + # Predicts with the default checkpoint. |
30 | 30 | response = client.models.generate_content( |
31 | 31 | model=tuning_job.tuned_model.endpoint, |
32 | 32 | contents=contents, |
33 | 33 | ) |
34 | 34 | print(response.text) |
| 35 | + # Example response: |
| 36 | + # The sky is blue because ... |
35 | 37 |
|
36 | | - # Tests Checkpoint 1 |
| 38 | + # Predicts with Checkpoint 1. |
37 | 39 | checkpoint1_response = client.models.generate_content( |
38 | 40 | model=tuning_job.tuned_model.checkpoints[0].endpoint, |
39 | 41 | contents=contents, |
40 | 42 | ) |
41 | 43 | print(checkpoint1_response.text) |
| 44 | + # Example response: |
| 45 | + # The sky is blue because ... |
42 | 46 |
|
43 | | - # Tests Checkpoint 2 |
| 47 | + # Predicts with Checkpoint 2. |
44 | 48 | checkpoint2_response = client.models.generate_content( |
45 | 49 | model=tuning_job.tuned_model.checkpoints[1].endpoint, |
46 | 50 | contents=contents, |
47 | 51 | ) |
48 | 52 | print(checkpoint2_response.text) |
| 53 | + # Example response: |
| 54 | + # The sky is blue because ... |
49 | 55 |
|
50 | 56 | # [END googlegenaisdk_tuning_with_checkpoints_test] |
51 | 57 | return response.text |
52 | 58 |
|
53 | 59 |
|
54 | 60 | if __name__ == "__main__": |
55 | | - tuning_job_name = input("Tuning job name: ") |
56 | | - test_checkpoint(tuning_job_name) |
| 61 | + input_tuning_job_name = input("Tuning job name: ") |
| 62 | + predict_with_checkpoints(input_tuning_job_name) |
0 commit comments