Skip to content

Commit 9165249

Browse files
committed
feat: Add multiple entity support to dbt integration
- Update CLI to accept multiple -e flags for entity columns
- Update mapper and codegen for multiple entities
- Add comprehensive tests for multi-entity feature views
- Update documentation with examples and usage
- Remove backward compatibility (feature is new)

This extends the dbt integration to support FeatureViews with multiple entities, enabling use cases like transaction features keyed by both user_id and merchant_id.

Fixes feast-dev#5872

Signed-off-by: yassinnouh21 <[email protected]>
1 parent 08d7077 commit 9165249

File tree

5 files changed

+320
-87
lines changed

5 files changed

+320
-87
lines changed

docs/how-to-guides/dbt-integration.md

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
**Current Limitations**:
77
- Supported data sources: BigQuery, Snowflake, and File-based sources only
8-
- Single entity per model
98
- Manual entity column specification required
109

1110
Breaking changes may occur in future releases.
@@ -185,6 +184,53 @@ driver_features_fv = FeatureView(
185184
```
186185
{% endcode %}
187186

187+
## Multiple Entity Support
188+
189+
The dbt integration supports feature views with multiple entities, useful for modeling relationships involving multiple keys.
190+
191+
### Usage
192+
193+
Specify multiple entity columns using repeated `-e` flags:
194+
195+
```bash
196+
feast dbt import \
197+
-m target/manifest.json \
198+
-e user_id \
199+
-e merchant_id \
200+
--tag feast \
201+
-o features/transactions.py
202+
```
203+
204+
This creates a FeatureView with both `user_id` and `merchant_id` as entities, useful for:
205+
- Transaction features keyed by both user and merchant
206+
- Interaction features keyed by multiple parties
207+
- Association tables in many-to-many relationships
208+
209+
For a single entity, pass one `-e` flag:
210+
```bash
211+
feast dbt import -m target/manifest.json -e driver_id --tag feast
212+
```
213+
214+
### Requirements
215+
216+
All specified entity columns must exist in each dbt model being imported. Models missing any entity column will be skipped with a warning.
217+
218+
### Generated Code
219+
220+
The `--output` flag generates code like:
221+
222+
```python
223+
user_id = Entity(name="user_id", join_keys=["user_id"], ...)
224+
merchant_id = Entity(name="merchant_id", join_keys=["merchant_id"], ...)
225+
226+
transaction_fv = FeatureView(
227+
name="transactions",
228+
entities=[user_id, merchant_id], # Multiple entities
229+
schema=[...],
230+
...
231+
)
232+
```
233+
188234
## CLI Reference
189235

190236
### `feast dbt list`
@@ -217,7 +263,7 @@ feast dbt import <manifest_path> [OPTIONS]
217263

218264
| Option | Description | Default |
219265
|--------|-------------|---------|
220-
| `--entity-column`, `-e` | Column to use as entity key | (required) |
266+
| `--entity-column`, `-e` | Entity column name (can be specified multiple times) | (required) |
221267
| `--data-source-type`, `-d` | Data source type: `bigquery`, `snowflake`, `file` | `bigquery` |
222268
| `--tag-filter`, `-t` | Filter models by dbt tag | None |
223269
| `--model`, `-m` | Import specific model(s) only | None |

sdk/python/feast/cli/dbt_import.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ def dbt_cmd():
3030
@click.option(
3131
"--entity-column",
3232
"-e",
33+
"entity_columns",
34+
multiple=True,
3335
required=True,
34-
help="Primary key / entity column name (e.g., driver_id, customer_id)",
36+
help="Entity column name (can be specified multiple times, e.g., -e user_id -e merchant_id)",
3537
)
3638
@click.option(
3739
"--data-source-type",
@@ -89,7 +91,7 @@ def dbt_cmd():
8991
def import_command(
9092
ctx: click.Context,
9193
manifest_path: str,
92-
entity_column: str,
94+
entity_columns: tuple,
9395
data_source_type: str,
9496
timestamp_field: str,
9597
tag_filter: Optional[str],
@@ -141,6 +143,28 @@ def import_command(
141143
if parser.project_name:
142144
click.echo(f" Project: {parser.project_name}")
143145

146+
# Convert tuple to list and validate
147+
entity_cols: List[str] = list(entity_columns) if entity_columns else []
148+
149+
# Validation: At least one entity required (redundant with required=True but explicit)
150+
if not entity_cols:
151+
click.echo(
152+
f"{Fore.RED}Error: At least one entity column required{Style.RESET_ALL}",
153+
err=True,
154+
)
155+
raise SystemExit(1)
156+
157+
# Validation: No duplicate entity columns
158+
if len(entity_cols) != len(set(entity_cols)):
159+
duplicates = [col for col in entity_cols if entity_cols.count(col) > 1]
160+
click.echo(
161+
f"{Fore.RED}Error: Duplicate entity columns: {', '.join(set(duplicates))}{Style.RESET_ALL}",
162+
err=True,
163+
)
164+
raise SystemExit(1)
165+
166+
click.echo(f"Entity columns: {', '.join(entity_cols)}")
167+
144168
# Get models with filters
145169
model_list: Optional[List[str]] = list(model_names) if model_names else None
146170
models = parser.get_models(model_names=model_list, tag_filter=tag_filter)
@@ -188,24 +212,28 @@ def import_command(
188212
)
189213
continue
190214

191-
# Validate entity column exists
192-
if entity_column not in column_names:
215+
# Validate ALL entity columns exist
216+
missing_entities = [e for e in entity_cols if e not in column_names]
217+
if missing_entities:
193218
click.echo(
194219
f"{Fore.YELLOW}Warning: Model '{model.name}' missing entity "
195-
f"column '{entity_column}'. Skipping.{Style.RESET_ALL}"
220+
f"column(s): {', '.join(missing_entities)}. Skipping.{Style.RESET_ALL}"
196221
)
197222
continue
198223

199-
# Create or reuse entity
200-
if entity_column not in entities_created:
201-
entity = mapper.create_entity(
202-
name=entity_column,
203-
description="Entity key for dbt models",
204-
)
205-
entities_created[entity_column] = entity
206-
all_objects.append(entity)
207-
else:
208-
entity = entities_created[entity_column]
224+
# Create or reuse entities (one per entity column)
225+
model_entities: List[Any] = []
226+
for entity_col in entity_cols:
227+
if entity_col not in entities_created:
228+
entity = mapper.create_entity(
229+
name=entity_col,
230+
description="Entity key for dbt models",
231+
)
232+
entities_created[entity_col] = entity
233+
all_objects.append(entity)
234+
else:
235+
entity = entities_created[entity_col]
236+
model_entities.append(entity)
209237

210238
# Create data source
211239
data_source = mapper.create_data_source(
@@ -218,8 +246,8 @@ def import_command(
218246
feature_view = mapper.create_feature_view(
219247
model=model,
220248
source=data_source,
221-
entity_column=entity_column,
222-
entity=entity,
249+
entity_columns=entity_cols,
250+
entities=model_entities,
223251
timestamp_field=timestamp_field,
224252
ttl_days=ttl_days,
225253
exclude_columns=excluded,
@@ -242,7 +270,7 @@ def import_command(
242270
m
243271
for m in models
244272
if timestamp_field in [c.name for c in m.columns]
245-
and entity_column in [c.name for c in m.columns]
273+
and all(e in [c.name for c in m.columns] for e in entity_cols)
246274
]
247275

248276
# Summary
@@ -257,7 +285,7 @@ def import_command(
257285

258286
code = generate_feast_code(
259287
models=valid_models,
260-
entity_column=entity_column,
288+
entity_columns=entity_cols,
261289
data_source_type=data_source_type,
262290
timestamp_field=timestamp_field,
263291
ttl_days=ttl_days,

sdk/python/feast/dbt/codegen.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
import logging
9-
from typing import Any, List, Optional, Set
9+
from typing import Any, List, Optional, Set, Union
1010

1111
from jinja2 import BaseLoader, Environment
1212

@@ -106,7 +106,7 @@
106106
{% for fv in feature_views %}
107107
{{ fv.var_name }} = FeatureView(
108108
name="{{ fv.name }}",
109-
entities=[{{ fv.entity_var }}],
109+
entities=[{{ fv.entity_vars | join(', ') }}],
110110
ttl=timedelta(days={{ fv.ttl_days }}),
111111
schema=[
112112
{% for field in fv.fields %}
@@ -220,7 +220,7 @@ def __init__(
220220
def generate(
221221
self,
222222
models: List[DbtModel],
223-
entity_column: str,
223+
entity_columns: Union[str, List[str]],
224224
manifest_path: str = "",
225225
project_name: str = "",
226226
exclude_columns: Optional[List[str]] = None,
@@ -231,7 +231,7 @@ def generate(
231231
232232
Args:
233233
models: List of DbtModel objects to generate code for
234-
entity_column: The entity/primary key column name
234+
entity_columns: Entity column name(s) - single string or list of strings
235235
manifest_path: Path to the dbt manifest (for documentation)
236236
project_name: dbt project name (for documentation)
237237
exclude_columns: Columns to exclude from features
@@ -240,25 +240,36 @@ def generate(
240240
Returns:
241241
Generated Python code as a string
242242
"""
243-
excluded = {entity_column, self.timestamp_field}
243+
# Normalize entity_columns to list
244+
entity_cols: List[str] = (
245+
[entity_columns] if isinstance(entity_columns, str) else entity_columns
246+
)
247+
248+
if not entity_cols:
249+
raise ValueError("At least one entity column must be specified")
250+
251+
excluded = set(entity_cols) | {self.timestamp_field}
244252
if exclude_columns:
245253
excluded.update(exclude_columns)
246254

247255
# Collect all Feast types used for imports
248256
type_imports: Set[str] = set()
249257

250-
# Prepare entity data
258+
# Prepare entity data - create one entity per entity column
251259
entities = []
252-
entity_var = _make_var_name(entity_column)
253-
entities.append(
254-
{
255-
"var_name": entity_var,
256-
"name": entity_column,
257-
"join_key": entity_column,
258-
"description": "Entity key for dbt models",
259-
"tags": {"source": "dbt"},
260-
}
261-
)
260+
entity_vars = [] # Track variable names for feature views
261+
for entity_col in entity_cols:
262+
entity_var = _make_var_name(entity_col)
263+
entity_vars.append(entity_var)
264+
entities.append(
265+
{
266+
"var_name": entity_var,
267+
"name": entity_col,
268+
"join_key": entity_col,
269+
"description": "Entity key for dbt models",
270+
"tags": {"source": "dbt"},
271+
}
272+
)
262273

263274
# Prepare data sources and feature views
264275
data_sources = []
@@ -269,7 +280,9 @@ def generate(
269280
column_names = [c.name for c in model.columns]
270281
if self.timestamp_field not in column_names:
271282
continue
272-
if entity_column not in column_names:
283+
284+
# Skip if ANY entity column is missing
285+
if not all(e in column_names for e in entity_cols):
273286
continue
274287

275288
# Build tags
@@ -339,7 +352,7 @@ def generate(
339352
{
340353
"var_name": fv_var,
341354
"name": model.name,
342-
"entity_var": entity_var,
355+
"entity_vars": entity_vars,
343356
"source_var": source_var,
344357
"ttl_days": self.ttl_days,
345358
"fields": fields,
@@ -366,7 +379,7 @@ def generate(
366379

367380
def generate_feast_code(
368381
models: List[DbtModel],
369-
entity_column: str,
382+
entity_columns: Union[str, List[str]],
370383
data_source_type: str = "bigquery",
371384
timestamp_field: str = "event_timestamp",
372385
ttl_days: int = 1,
@@ -380,7 +393,7 @@ def generate_feast_code(
380393
381394
Args:
382395
models: List of DbtModel objects
383-
entity_column: Primary key column name
396+
entity_columns: Entity column name(s) - single string or list of strings
384397
data_source_type: Type of data source (bigquery, snowflake, file)
385398
timestamp_field: Timestamp column name
386399
ttl_days: TTL in days for feature views
@@ -400,7 +413,7 @@ def generate_feast_code(
400413

401414
return generator.generate(
402415
models=models,
403-
entity_column=entity_column,
416+
entity_columns=entity_columns,
404417
manifest_path=manifest_path,
405418
project_name=project_name,
406419
exclude_columns=exclude_columns,

0 commit comments

Comments
 (0)