Skip to content

Commit dffa7f5

Browse files
committed
Fix BigQuery connector variable iteration for protobuf changes
- Update BigQueryConnector to iterate over VariableEntry list instead of dict - Fix test fixtures to use VariableEntry with key/value attributes - Add comprehensive test for new variable iteration pattern - Update Snowflake tests for same protobuf structure change - Fix workflow and Makefile for plugin testing Signed-off-by: Kevin Su <pingsutw@apache.org>
1 parent 05ab2d1 commit dffa7f5

File tree

5 files changed

+87
-15
lines changed

5 files changed

+87
-15
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,5 @@ jobs:
6060

6161
- name: Run tests
6262
env:
63-
FLYTE_PLUGINS: ${{ matrix.plugin }}
63+
FLYTE_PLUGINS: plugins/${{ matrix.plugin }}
6464
run: make unit_test_plugins

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ unit_test_plugins:
7575
@for plugin in $${FLYTE_PLUGIN:-plugins/*}; do \
7676
if [ -d "$$plugin/tests" ]; then \
7777
echo "🚀 Testing plugin: $$plugin..."; \
78-
cd "$$plugin" && uv run python -m pytest tests/ && cd ../../..; \
78+
cd "$$plugin" && uv run --active python -m pytest tests/ && cd ../../..; \
7979
fi \
8080
done
8181

plugins/bigquery/src/flyteplugins/bigquery/connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ async def create(
6969
job_config = None
7070
if inputs:
7171
python_interface_inputs = {
72-
name: TypeEngine.guess_python_type(lt.type)
73-
for name, lt in task_template.interface.inputs.variables.items()
72+
variable.key: TypeEngine.guess_python_type(variable.value.type)
73+
for variable in task_template.interface.inputs.variables
7474
}
7575
job_config = bigquery.QueryJobConfig(
7676
query_parameters=[

plugins/bigquery/tests/test_connector.py

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from flyte.io import DataFrame
77
from flyteidl2.core.execution_pb2 import TaskExecution
8-
from flyteidl2.core.interface_pb2 import Variable, VariableMap
8+
from flyteidl2.core.interface_pb2 import Variable, VariableEntry, VariableMap
99
from flyteidl2.core.tasks_pb2 import Sql, TaskTemplate
1010
from flyteidl2.core.types_pb2 import LiteralType, SimpleType
1111
from google.protobuf import struct_pb2
@@ -72,13 +72,14 @@ def task_template_with_inputs(self):
7272
template.sql.CopyFrom(Sql(statement="SELECT * FROM table WHERE id = @user_id", dialect=Sql.Dialect.ANSI))
7373
template.metadata.runtime.version = "1.0.0"
7474

75-
# Add input variables
75+
# Add input variables using the new list-based structure
7676
int_type = LiteralType()
7777
int_type.simple = SimpleType.INTEGER
7878
user_id_var = Variable(type=int_type)
7979

8080
variables = VariableMap()
81-
variables.variables["user_id"].CopyFrom(user_id_var)
81+
var_entry = VariableEntry(key="user_id", value=user_id_var)
82+
variables.variables.append(var_entry)
8283
template.interface.inputs.CopyFrom(variables)
8384

8485
custom = struct_pb2.Struct()
@@ -329,7 +330,7 @@ async def test_create_with_multiple_input_types(self, connector):
329330
)
330331
template.metadata.runtime.version = "1.0.0"
331332

332-
# Add multiple input variables with different types
333+
# Add multiple input variables with different types using the new list-based structure
333334
int_type = LiteralType()
334335
int_type.simple = SimpleType.INTEGER
335336
str_type = LiteralType()
@@ -338,9 +339,9 @@ async def test_create_with_multiple_input_types(self, connector):
338339
bool_type.simple = SimpleType.BOOLEAN
339340

340341
variables = VariableMap()
341-
variables.variables["user_id"].CopyFrom(Variable(type=int_type))
342-
variables.variables["name"].CopyFrom(Variable(type=str_type))
343-
variables.variables["active"].CopyFrom(Variable(type=bool_type))
342+
variables.variables.append(VariableEntry(key="user_id", value=Variable(type=int_type)))
343+
variables.variables.append(VariableEntry(key="name", value=Variable(type=str_type)))
344+
variables.variables.append(VariableEntry(key="active", value=Variable(type=bool_type)))
344345
template.interface.inputs.CopyFrom(variables)
345346

346347
custom = struct_pb2.Struct()
@@ -447,3 +448,71 @@ async def test_create_with_google_application_credentials(self, connector, task_
447448
# Verify the credentials were passed to the client
448449
mock_client_class.assert_called_once()
449450
assert mock_client_class.call_args[1]["credentials"] == mock_credentials
451+
452+
@pytest.mark.asyncio
453+
async def test_create_iterates_variables_with_new_structure(self, connector):
454+
"""Test that the connector correctly iterates over variables using the new iteration pattern.
455+
456+
This test verifies the change from:
457+
for name, lt in task_template.interface.inputs.variables.items()
458+
To:
459+
for variable in task_template.interface.inputs.variables
460+
461+
The variables field changed from a map to a repeated field (list), so we now
462+
iterate directly over the list of Variable objects which have key and value attributes.
463+
"""
464+
template = TaskTemplate()
465+
template.sql.CopyFrom(
466+
Sql(
467+
statement="SELECT * FROM table WHERE user_id = @user_id AND email = @email",
468+
dialect=Sql.Dialect.ANSI,
469+
)
470+
)
471+
template.metadata.runtime.version = "2.0.0"
472+
473+
# Create variables using the new list-based VariableMap structure
474+
int_type = LiteralType()
475+
int_type.simple = SimpleType.INTEGER
476+
str_type = LiteralType()
477+
str_type.simple = SimpleType.STRING
478+
479+
variables = VariableMap()
480+
variables.variables.append(VariableEntry(key="user_id", value=Variable(type=int_type)))
481+
variables.variables.append(VariableEntry(key="email", value=Variable(type=str_type)))
482+
template.interface.inputs.CopyFrom(variables)
483+
484+
custom = struct_pb2.Struct()
485+
custom["ProjectID"] = "test-project"
486+
custom["Location"] = "US"
487+
custom["Domain"] = "test-domain"
488+
template.custom.CopyFrom(custom)
489+
490+
with patch("flyteplugins.bigquery.connector.bigquery.Client") as mock_client_class:
491+
mock_client = MagicMock()
492+
mock_client_class.return_value = mock_client
493+
494+
mock_query_job = MagicMock()
495+
mock_query_job.job_id = "job-iteration-test"
496+
mock_client.query.return_value = mock_query_job
497+
498+
inputs = {"user_id": 42, "email": "test@example.com"}
499+
metadata = await connector.create(template, inputs=inputs)
500+
501+
assert metadata.job_id == "job-iteration-test"
502+
503+
# Verify that the query was called with proper parameters
504+
call_args = mock_client.query.call_args
505+
job_config = call_args[1]["job_config"]
506+
507+
# The new iteration pattern should successfully create query parameters
508+
assert len(job_config.query_parameters) == 2
509+
510+
param_dict = {p.name: p for p in job_config.query_parameters}
511+
assert "user_id" in param_dict
512+
assert "email" in param_dict
513+
assert param_dict["user_id"].value == 42
514+
assert param_dict["email"].value == "test@example.com"
515+
516+
# Verify parameter types are correctly mapped
517+
assert param_dict["user_id"].type_ == "INT64"
518+
assert param_dict["email"].type_ == "STRING"

plugins/snowflake/tests/test_connector.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from flyte.io import DataFrame
55
from flyteidl2.core.execution_pb2 import TaskExecution
6-
from flyteidl2.core.interface_pb2 import Variable, VariableMap
6+
from flyteidl2.core.interface_pb2 import Variable, VariableEntry, VariableMap
77
from flyteidl2.core.tasks_pb2 import Sql, TaskTemplate
88
from flyteidl2.core.types_pb2 import LiteralType, StructuredDatasetType
99
from google.protobuf import struct_pb2
@@ -104,11 +104,14 @@ def task_template_with_output(self):
104104
template.custom.CopyFrom(custom)
105105

106106
# Set output variables so has_output is True
107-
template.interface.outputs.CopyFrom(
108-
VariableMap(
109-
variables={"results": Variable(type=LiteralType(structured_dataset_type=StructuredDatasetType()))}
107+
output_vars = VariableMap()
108+
output_vars.variables.append(
109+
VariableEntry(
110+
key="results",
111+
value=Variable(type=LiteralType(structured_dataset_type=StructuredDatasetType()))
110112
)
111113
)
114+
template.interface.outputs.CopyFrom(output_vars)
112115

113116
return template
114117

0 commit comments

Comments
 (0)