13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
15
import copy
16
+ from collections .abc import Iterable
16
17
import datetime
17
18
import dataclasses
18
19
import pathlib
@@ -42,6 +43,8 @@ def setUp(self):
42
43
43
44
client ._client_manager .model_client = self .client
44
45
46
+ # TODO(markdaoust): Check if typechecking works better if wee define this as a
47
+ # subclass of `glm.ModelServiceClient`, would pyi files for `glm` help?
45
48
def add_client_method (f ):
46
49
name = f .__name__
47
50
setattr (self .client , name , f )
@@ -70,10 +73,6 @@ def get_tuned_model(
70
73
response = copy .copy (self .responses ["get_tuned_model" ])
71
74
return response
72
75
73
- @dataclasses .dataclass
74
- class ListWrapper :
75
- _response : Any
76
-
77
76
@add_client_method
78
77
def list_models (
79
78
request : Union [glm .ListModelsRequest , None ] = None ,
@@ -85,25 +84,25 @@ def list_models(
85
84
request = glm .ListModelsRequest (page_size = page_size , page_token = page_token )
86
85
self .assertIsInstance (request , glm .ListModelsRequest )
87
86
self .observed_requests .append (request )
88
- response = self .responses ["list_models" ][ request . page_token ]
89
- return ListWrapper ( response )
87
+ response = self .responses ["list_models" ]
88
+ return ( item for item in response )
90
89
91
90
@add_client_method
92
91
def list_tuned_models (
93
92
request : glm .ListTunedModelsRequest = None ,
94
93
* ,
95
94
page_size = None ,
96
95
page_token = None ,
97
- ) -> glm .ListModelsResponse :
96
+ ) -> Iterable [ glm .TunedModel ] :
98
97
if request is None :
99
98
request = glm .ListTunedModelsRequest (page_size = page_size , page_token = page_token )
100
99
self .assertIsInstance (request , glm .ListTunedModelsRequest )
101
100
self .observed_requests .append (request )
102
- response = self .responses ["list_tuned_models" ][ request . page_token ]
103
- return ListWrapper ( response )
101
+ response = self .responses ["list_tuned_models" ]
102
+ return ( item for item in response )
104
103
105
104
@add_client_method
106
- def update_tuned_model (request : glm .UpdateTunedModelRequest ):
105
+ def update_tuned_model (request : glm .UpdateTunedModelRequest ) -> glm . TunedModel :
107
106
self .observed_requests .append (request )
108
107
response = self .responses .get ("update_tuned_model" , None )
109
108
if response is None :
@@ -156,24 +155,13 @@ def test_fail_with_unscoped_model_name(self, name):
156
155
model = models .get_model (name )
157
156
158
157
def test_list_models (self ):
158
+ # The low level lib wraps the response in an iterable, so this is a fair test.
159
159
self .responses = {
160
- "list_models" : {
161
- # The first request doesn't pass a page token
162
- "" : glm .ListModelsResponse (
163
- models = [
164
- glm .Model (name = "models/fake-bison-001" ),
165
- glm .Model (name = "models/fake-bison-002" ),
166
- ],
167
- next_page_token = "page1" ,
168
- ),
169
- "page1" : glm .ListModelsResponse (
170
- models = [
171
- glm .Model (name = "models/fake-bison-003" ),
172
- ],
173
- # The last page returns an empty page token.
174
- next_page_token = "" ,
175
- ),
176
- }
160
+ "list_models" : [
161
+ glm .Model (name = "models/fake-bison-001" ),
162
+ glm .Model (name = "models/fake-bison-002" ),
163
+ glm .Model (name = "models/fake-bison-003" ),
164
+ ]
177
165
}
178
166
179
167
found_models = list (models .list_models ())
@@ -183,23 +171,12 @@ def test_list_models(self):
183
171
184
172
def test_list_tuned_models (self ):
185
173
self .responses = {
186
- "list_tuned_models" : {
187
- # The first request doesn't pass a page token
188
- "" : glm .ListTunedModelsResponse (
189
- tuned_models = [
190
- glm .TunedModel (name = "tunedModels/my-pig-001" ),
191
- glm .TunedModel (name = "tunedModels/my-pig-002" ),
192
- ],
193
- next_page_token = "page1" ,
194
- ),
195
- "page1" : glm .ListTunedModelsResponse (
196
- tuned_models = [
197
- glm .TunedModel (name = "tunedModels/my-pig-003" ),
198
- ],
199
- # The last page returns an empty page token.
200
- next_page_token = "" ,
201
- ),
202
- }
174
+ # The low level lib wraps the response in an iterable, so this is a fair test.
175
+ "list_tuned_models" : [
176
+ glm .TunedModel (name = "tunedModels/my-pig-001" ),
177
+ glm .TunedModel (name = "tunedModels/my-pig-002" ),
178
+ glm .TunedModel (name = "tunedModels/my-pig-003" ),
179
+ ]
203
180
}
204
181
found_models = list (models .list_tuned_models ())
205
182
self .assertLen (found_models , 3 )
0 commit comments