
NN Built-In for Embedding Layers #2237

Closed
MaximilianSchreff wants to merge 5 commits into apache:main from MaximilianSchreff:embedding_layer2


@MaximilianSchreff
Contributor

This PR adds the embedding layer as a built-in operator to our nn/layers library. Its functionality is similar to PyTorch's torch.nn.Embedding (https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html).

The layer receives indices as input, each of which refers to an entry of an embedding dictionary, and returns an embedding matrix in which row i is the embedding vector stored at position indices[i] of the dictionary.

Embedding layers are a standard component of transformer architectures. There, the indices usually come from a tokenizer, and the resulting embedding matrix is the input to the actual transformer model.
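The lookup described above can be illustrated with a minimal sketch in plain Python. This is not the code from this PR: the function name `embedding_forward`, the list-of-lists representation, and the 0-based indexing are assumptions made for illustration only.

```python
def embedding_forward(indices, embedding_dict):
    """Return a matrix whose row i is embedding_dict[indices[i]].

    Hypothetical helper for illustration; indices are assumed to be 0-based.
    """
    vocab_size = len(embedding_dict)
    for idx in indices:
        if not 0 <= idx < vocab_size:
            raise IndexError(f"index {idx} is outside the dictionary (size {vocab_size})")
    # Copy each selected row so the caller cannot mutate the dictionary by accident.
    return [list(embedding_dict[idx]) for idx in indices]


# A toy dictionary with 3 entries of dimension 2.
embedding_dict = [
    [0.1, 0.2],
    [0.3, 0.4],
    [0.5, 0.6],
]
token_ids = [2, 0, 2]  # e.g. produced by a tokenizer
embeddings = embedding_forward(token_ids, embedding_dict)
```

Here `embeddings` has one row per input index, and a repeated index (2 in this example) yields repeated rows.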

Testing

  • Testing forward pass and backward pass for correctness
  • Implemented as a component test in NNComponentTest.java
  • Manually calculated test cases for the forward pass
  • For the backward pass, comparison against PyTorch's autograd module
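For context on what the backward-pass test checks, the gradient of a lookup with respect to the embedding dictionary is a scatter-add: each row of the upstream gradient is added to the dictionary row that was selected. The sketch below shows this in plain Python. The function name `embedding_backward`, its signature, and the 0-based indexing are assumptions for illustration and are not taken from this PR.

```python
def embedding_backward(grad_output, indices, vocab_size):
    """Accumulate upstream gradients into a dictionary-shaped gradient.

    Hypothetical helper for illustration; indices are assumed to be 0-based.
    grad_output has one row per index, matching the forward output.
    """
    dim = len(grad_output[0]) if grad_output else 0
    grad_dict = [[0.0] * dim for _ in range(vocab_size)]
    for row, idx in zip(grad_output, indices):
        for j, g in enumerate(row):
            # Rows selected more than once receive the sum of their gradients.
            grad_dict[idx][j] += g
    return grad_dict


grad_output = [
    [1.0, 2.0],
    [3.0, 4.0],
    [5.0, 6.0],
]
lookup_indices = [2, 0, 2]
grad_dict = embedding_backward(grad_output, lookup_indices, vocab_size=3)
```

In this example entry 1 was never looked up, so its gradient row stays at zero, while entry 2 was looked up twice and receives the sum of the first and third gradient rows.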

@codecov

codecov bot commented Feb 25, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 72.47%. Comparing base (78b23cf) to head (336ef19).
Report is 22 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff            @@
##               main    #2237   +/-   ##
=========================================
  Coverage     72.46%   72.47%           
- Complexity    45453    45465   +12     
=========================================
  Files          1469     1469           
  Lines        170893   170893           
  Branches      33325    33325           
=========================================
+ Hits         123846   123863   +17     
+ Misses        37630    37617   -13     
+ Partials       9417     9413    -4     


@phaniarnab
Contributor

Thanks @MaximilianSchreff. I will merge it in.

@phaniarnab closed this in e97f410 on Apr 6, 2025
@github-project-automation bot moved this from In Progress to Done in SystemDS PR Queue on Apr 6, 2025