55import pytest
66from flyte .io import DataFrame
77from flyteidl2 .core .execution_pb2 import TaskExecution
8- from flyteidl2 .core .interface_pb2 import Variable , VariableMap
8+ from flyteidl2 .core .interface_pb2 import Variable , VariableEntry , VariableMap
99from flyteidl2 .core .tasks_pb2 import Sql , TaskTemplate
1010from flyteidl2 .core .types_pb2 import LiteralType , SimpleType
1111from google .protobuf import struct_pb2
@@ -72,13 +72,14 @@ def task_template_with_inputs(self):
7272 template .sql .CopyFrom (Sql (statement = "SELECT * FROM table WHERE id = @user_id" , dialect = Sql .Dialect .ANSI ))
7373 template .metadata .runtime .version = "1.0.0"
7474
75- # Add input variables
75+ # Add input variables using the new list-based structure
7676 int_type = LiteralType ()
7777 int_type .simple = SimpleType .INTEGER
7878 user_id_var = Variable (type = int_type )
7979
8080 variables = VariableMap ()
81- variables .variables ["user_id" ].CopyFrom (user_id_var )
81+ var_entry = VariableEntry (key = "user_id" , value = user_id_var )
82+ variables .variables .append (var_entry )
8283 template .interface .inputs .CopyFrom (variables )
8384
8485 custom = struct_pb2 .Struct ()
@@ -329,7 +330,7 @@ async def test_create_with_multiple_input_types(self, connector):
329330 )
330331 template .metadata .runtime .version = "1.0.0"
331332
332- # Add multiple input variables with different types
333+ # Add multiple input variables with different types using the new list-based structure
333334 int_type = LiteralType ()
334335 int_type .simple = SimpleType .INTEGER
335336 str_type = LiteralType ()
@@ -338,9 +339,9 @@ async def test_create_with_multiple_input_types(self, connector):
338339 bool_type .simple = SimpleType .BOOLEAN
339340
340341 variables = VariableMap ()
341- variables .variables [ "user_id" ]. CopyFrom ( Variable (type = int_type ))
342- variables .variables [ "name" ]. CopyFrom ( Variable (type = str_type ))
343- variables .variables [ "active" ]. CopyFrom ( Variable (type = bool_type ))
342+ variables .variables . append ( VariableEntry ( key = "user_id" , value = Variable (type = int_type ) ))
343+ variables .variables . append ( VariableEntry ( key = "name" , value = Variable (type = str_type ) ))
344+ variables .variables . append ( VariableEntry ( key = "active" , value = Variable (type = bool_type ) ))
344345 template .interface .inputs .CopyFrom (variables )
345346
346347 custom = struct_pb2 .Struct ()
@@ -447,3 +448,71 @@ async def test_create_with_google_application_credentials(self, connector, task_
447448 # Verify the credentials were passed to the client
448449 mock_client_class .assert_called_once ()
449450 assert mock_client_class .call_args [1 ]["credentials" ] == mock_credentials
451+
452+ @pytest .mark .asyncio
453+ async def test_create_iterates_variables_with_new_structure (self , connector ):
454+ """Test that the connector correctly iterates over variables using the new iteration pattern.
455+
456+ This test verifies the change from:
457+ for name, lt in task_template.interface.inputs.variables.items()
458+ To:
459+ for variable in task_template.interface.inputs.variables
460+
461+ The variables field changed from a map to a repeated field (list), so we now
462+ iterate directly over the list of Variable objects which have key and value attributes.
463+ """
464+ template = TaskTemplate ()
465+ template .sql .CopyFrom (
466+ Sql (
467+ statement = "SELECT * FROM table WHERE user_id = @user_id AND email = @email" ,
468+ dialect = Sql .Dialect .ANSI ,
469+ )
470+ )
471+ template .metadata .runtime .version = "2.0.0"
472+
473+ # Create variables using the new list-based VariableMap structure
474+ int_type = LiteralType ()
475+ int_type .simple = SimpleType .INTEGER
476+ str_type = LiteralType ()
477+ str_type .simple = SimpleType .STRING
478+
479+ variables = VariableMap ()
480+ variables .variables .append (VariableEntry (key = "user_id" , value = Variable (type = int_type )))
481+ variables .variables .append (VariableEntry (key = "email" , value = Variable (type = str_type )))
482+ template .interface .inputs .CopyFrom (variables )
483+
484+ custom = struct_pb2 .Struct ()
485+ custom ["ProjectID" ] = "test-project"
486+ custom ["Location" ] = "US"
487+ custom ["Domain" ] = "test-domain"
488+ template .custom .CopyFrom (custom )
489+
490+ with patch ("flyteplugins.bigquery.connector.bigquery.Client" ) as mock_client_class :
491+ mock_client = MagicMock ()
492+ mock_client_class .return_value = mock_client
493+
494+ mock_query_job = MagicMock ()
495+ mock_query_job .job_id = "job-iteration-test"
496+ mock_client .query .return_value = mock_query_job
497+
498+ inputs = {"user_id" : 42 , "email" : "test@example.com" }
499+ metadata = await connector .create (template , inputs = inputs )
500+
501+ assert metadata .job_id == "job-iteration-test"
502+
503+ # Verify that the query was called with proper parameters
504+ call_args = mock_client .query .call_args
505+ job_config = call_args [1 ]["job_config" ]
506+
507+ # The new iteration pattern should successfully create query parameters
508+ assert len (job_config .query_parameters ) == 2
509+
510+ param_dict = {p .name : p for p in job_config .query_parameters }
511+ assert "user_id" in param_dict
512+ assert "email" in param_dict
513+ assert param_dict ["user_id" ].value == 42
514+ assert param_dict ["email" ].value == "test@example.com"
515+
516+ # Verify parameter types are correctly mapped
517+ assert param_dict ["user_id" ].type_ == "INT64"
518+ assert param_dict ["email" ].type_ == "STRING"
0 commit comments