[TASK] Add select_by_tag to ParallelBlock #694

@marcromeyn

Description

We would like to add the ability to select sub-graphs from a ParallelBlock. This can be useful, for instance, to select the item-id embedding table from an InputBlock. Another use case is enabling shared embeddings in two-tower-like models. At the moment, we instantiate a separate InputBlock per tower, which does not allow for shared embeddings. One way to enable this is to select the right feature branches from a single InputBlock.

As an example, let's say we have the following schema:

  • User features: user-id, last-purchase (encoded with the same embedding table as item-id)
  • Item features: item-id

all_inputs = InputBlockV2(schema)
# This would result in a parallel-block with 2 branches:
# user-id -> EmbeddingTable(user_id)
# item-id, last-purchase -> EmbeddingTable(item_id, last_purchase)

user_inputs = all_inputs.select_by_tag(Tags.USER)
# Results in a parallel-block with 2 branches:
# user-id -> EmbeddingTable(user_id)
# last-purchase -> EmbeddingTable(item_id, last_purchase)

item_inputs = all_inputs.select_by_tag(Tags.ITEM)
# Results in a parallel-block with 1 branch:
# item-id -> EmbeddingTable(item_id, last_purchase)
As the example shows, select_by_tag sub-selects the branches of a ParallelBlock from a feature perspective: a branch is kept when at least one of its features carries the requested tag, and it is narrowed to those features.
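
To make the intended semantics concrete, here is a minimal pure-Python sketch of the selection logic. It is not the Merlin implementation: ToyTable, ToyBranch, ToyParallelBlock, and the string tags "user" and "item" are hypothetical stand-ins, and the rule that a branch survives when any of its features matches is an assumption read off the example above.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet, Set


class ToyTable:
    """Hypothetical stand-in for an embedding table (holds no real weights)."""

    def __init__(self, name: str):
        self.name = name


@dataclass(frozen=True)
class ToyBranch:
    """One branch of the toy block: a table plus the features routed through it."""

    table: ToyTable
    features: FrozenSet[str]


class ToyParallelBlock:
    """Hypothetical stand-in for ParallelBlock, keyed by branch name."""

    def __init__(self, branches: Dict[str, ToyBranch], feature_tags: Dict[str, Set[str]]):
        self.branches = branches
        self.feature_tags = feature_tags

    def select_by_tag(self, tag: str) -> "ToyParallelBlock":
        selected = {}
        for name, branch in self.branches.items():
            # Keep only the features of this branch that carry the tag.
            kept = frozenset(f for f in branch.features if tag in self.feature_tags[f])
            if kept:
                # Reuse the same table object so the selection shares it with the original.
                selected[name] = ToyBranch(branch.table, kept)
        return ToyParallelBlock(selected, self.feature_tags)


# Schema from the example: last_purchase is a user feature encoded with the item-id table.
feature_tags = {"user_id": {"user"}, "last_purchase": {"user"}, "item_id": {"item"}}
all_inputs = ToyParallelBlock(
    {
        "user_id": ToyBranch(ToyTable("user_id"), frozenset({"user_id"})),
        "item_id": ToyBranch(ToyTable("item_id"), frozenset({"item_id", "last_purchase"})),
    },
    feature_tags,
)

user_inputs = all_inputs.select_by_tag("user")  # branches: user_id, item_id (last_purchase only)
item_inputs = all_inputs.select_by_tag("item")  # branches: item_id (item_id only)
```

In this sketch both selections hold a reference to the same ToyTable for the item-id branch, which is the property that would make shared embeddings across the two towers possible.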
