Skip to content

Commit 47b0a18

Browse files
committed
Fix samples for 0.8.5 release
Change-Id: I48f3389565d68bc10f1608980298e6f1e384fc37
1 parent 85a027e commit 47b0a18

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

samples/tuned_models.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222

2323

2424
class UnitTests(absltest.TestCase):
25-
def test_tuned_models_create(self):
26-
# [START tuned_models_create]
25+
@classmethod
26+
def setUpClass(cls):
27+
# Code to run once before all tests in the class
28+
# [START tuned_models_create]
2729
import google.generativeai as genai
2830

2931
import time
@@ -53,7 +55,7 @@ def test_tuned_models_create(self):
5355
# You can use a tuned model here too. Set `source_model="tunedModels/..."`
5456
display_name="increment",
5557
source_model=base_model,
56-
epoch_count=20,
58+
epoch_count=5,
5759
batch_size=4,
5860
learning_rate=0.001,
5961
training_data=training_data,
@@ -62,22 +64,25 @@ def test_tuned_models_create(self):
6264
for status in operation.wait_bar():
6365
time.sleep(10)
6466

65-
result = operation.result()
66-
print(result)
67+
tuned_model = operation.result()
68+
print(tuned_model)
6769
# # You can plot the loss curve with:
6870
# snapshots = pd.DataFrame(result.tuning_task.snapshots)
6971
# sns.lineplot(data=snapshots, x='epoch', y='mean_loss')
7072

71-
model = genai.GenerativeModel(model_name=result.name)
73+
model = genai.GenerativeModel(model_name=tuned_model.name)
7274
result = model.generate_content("III")
7375
print(result.text) # IV
7476
# [END tuned_models_create]
77+
78+
cls.tuned_model_name = tuned_model_name = tuned_model.name
79+
7580

7681
def test_tuned_models_generate_content(self):
7782
# [START tuned_models_generate_content]
7883
import google.generativeai as genai
7984

80-
model = genai.GenerativeModel(model_name="tunedModels/my-increment-model")
85+
model = genai.GenerativeModel(model_name=self.tuned_model_name)
8186
result = model.generate_content("III")
8287
print(result.text) # "IV"
8388
# [END tuned_models_generate_content]
@@ -86,7 +91,7 @@ def test_tuned_models_get(self):
8691
# [START tuned_models_get]
8792
import google.generativeai as genai
8893

89-
model_info = genai.get_model("tunedModels/my-increment-model")
94+
model_info = genai.get_model(self.tuned_model_name)
9095
print(model_info)
9196
# [END tuned_models_get]
9297

@@ -100,6 +105,7 @@ def test_tuned_models_list(self):
100105

101106
def test_tuned_models_delete(self):
102107
import time
108+
import google.generativeai as genai
103109

104110
base_model = "models/gemini-1.5-flash-001-tuning"
105111
training_data = samples / "increment_tuning_data.json"
@@ -109,7 +115,7 @@ def test_tuned_models_delete(self):
109115
# You can use a tuned model here too. Set `source_model="tunedModels/..."`
110116
display_name="increment",
111117
source_model=base_model,
112-
epoch_count=20,
118+
epoch_count=5,
113119
batch_size=4,
114120
learning_rate=0.001,
115121
training_data=training_data,
@@ -135,7 +141,7 @@ def test_tuned_models_permissions_create(self):
135141
# [START tuned_models_permissions_create]
136142
import google.generativeai as genai
137143

138-
model_info = genai.get_model("tunedModels/my-increment-model")
144+
model_info = genai.get_model(self.tuned_model_name)
139145
# [START_EXCLUDE]
140146
for p in model_info.permissions.list():
141147
if p.role.name != "OWNER":
@@ -161,7 +167,7 @@ def test_tuned_models_permissions_list(self):
161167
# [START tuned_models_permissions_list]
162168
import google.generativeai as genai
163169

164-
model_info = genai.get_model("tunedModels/my-increment-model")
170+
model_info = genai.get_model(self.tuned_model_name)
165171

166172
# [START_EXCLUDE]
167173
for p in model_info.permissions.list():
@@ -190,7 +196,7 @@ def test_tuned_models_permissions_get(self):
190196
# [START tuned_models_permissions_get]
191197
import google.generativeai as genai
192198

193-
model_info = genai.get_model("tunedModels/my-increment-model")
199+
model_info = genai.get_model(self.tuned_model_name)
194200

195201
# [START_EXCLUDE]
196202
for p in model_info.permissions.list():
@@ -214,7 +220,7 @@ def test_tuned_models_permissions_update(self):
214220
# [START tuned_models_permissions_update]
215221
import google.generativeai as genai
216222

217-
model_info = genai.get_model("tunedModels/my-increment-model")
223+
model_info = genai.get_model(self.tuned_model_name)
218224

219225
# [START_EXCLUDE]
220226
for p in model_info.permissions.list():
@@ -235,7 +241,7 @@ def test_tuned_models_permission_delete(self):
235241
# [START tuned_models_permissions_delete]
236242
import google.generativeai as genai
237243

238-
model_info = genai.get_model("tunedModels/my-increment-model")
244+
model_info = genai.get_model(self.tuned_model_name)
239245
# [START_EXCLUDE]
240246
for p in model_info.permissions.list():
241247
if p.role.name != "OWNER":

0 commit comments

Comments
 (0)