Multi-Head Attention Layer Implementation #2172

Closed
MaximilianSchreff wants to merge 25 commits into apache:main from MaximilianSchreff:multi-attention
@MaximilianSchreff
Contributor

This PR introduces the multi-head attention layer as a built-in layer with forward and backward passes.

Description

The multi-head attention layer is the base layer of almost all Transformer models, with many variations across models. This implementation is in line with the basic BERT attention layer. The functionality is currently kept to a minimum and omits features such as attention masking, head masking, and cross-attention.
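
For readers unfamiliar with the layer, here is a minimal sketch of an unmasked multi-head self-attention forward pass in plain Python. It is not the code from this PR: the function names (`matmul`, `softmax`, `multi_head_attention`) and the choice to pass already-projected query, key, and value matrices are assumptions made for illustration, and dropout is left out.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def multi_head_attention(q, k, v, num_heads):
    """Unmasked scaled dot-product attention, computed per head.

    q, k, v are (seq_len x d_model) matrices that are assumed to be
    already projected (hypothetical interface, not the PR's).
    Returns a (seq_len x d_model) matrix of concatenated head outputs.
    """
    d_model = len(q[0])
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_head = d_model // num_heads
    out = [[] for _ in q]
    for h in range(num_heads):
        cols = slice(h * d_head, (h + 1) * d_head)
        # Slice out the columns that belong to this head.
        qh = [row[cols] for row in q]
        kh = [row[cols] for row in k]
        vh = [row[cols] for row in v]
        # scores[i][j] = <q_i, k_j> / sqrt(d_head)
        k_t = [list(col) for col in zip(*kh)]
        scores = matmul(qh, k_t)
        probs = [softmax([s / math.sqrt(d_head) for s in row]) for row in scores]
        # Each output row is an attention-weighted average of the value rows.
        context = matmul(probs, vh)
        for i, row in enumerate(context):
            out[i].extend(row)
    return out
```

With identical keys every token attends uniformly, so each head returns the mean of its value columns; that property makes a convenient sanity check for a sketch like this.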

Testing

  • A new testing module was implemented specifically for this layer, extending the automated testing base
  • The tests execute the forward and backward passes with given inputs and compare the outputs against expected outputs
  • The implementation is compared against the HuggingFace Transformers library implementation
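
The comparison step described in the bullets above can be sketched as an element-wise check within a tolerance. This is an illustrative helper only: the names `max_abs_diff` and `matrices_close` and the default tolerance of `1e-5` are assumptions, not values or functions taken from the PR's test module.

```python
def max_abs_diff(actual, expected):
    """Largest element-wise absolute difference between two matrices.

    Both arguments are lists of rows and must have the same shape.
    """
    same_rows = len(actual) == len(expected)
    if not same_rows or any(len(a) != len(e) for a, e in zip(actual, expected)):
        raise ValueError("shape mismatch between actual and expected")
    return max(
        (abs(x - y) for a, e in zip(actual, expected) for x, y in zip(a, e)),
        default=0.0,
    )


def matrices_close(actual, expected, tol=1e-5):
    """True if every element of actual is within tol of expected.

    tol is a hypothetical default; choose it to match the precision
    of the reference outputs being compared against.
    """
    return max_abs_diff(actual, expected) <= tol
```

An absolute tolerance is the simplest choice here; when expected values span several orders of magnitude, a relative tolerance may be more appropriate.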

Notes

This PR is the first in a series of PRs working toward support for the BERT model in SystemDS, and for other transformer models in the future.

@codecov

codecov bot commented Jan 1, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 71.86%. Comparing base (082cf89) to head (4ec2c80).
Report is 30 commits behind head on main.

Additional details and impacted files
@@             Coverage Diff              @@
##               main    #2172      +/-   ##
============================================
- Coverage     71.97%   71.86%   -0.11%     
- Complexity    43855    44427     +572     
============================================
  Files          1441     1445       +4     
  Lines        166018   168173    +2155     
  Branches      32396    32827     +431     
============================================
+ Hits         119494   120865    +1371     
- Misses        37294    38019     +725     
- Partials       9230     9289      +59     


@phaniarnab
Contributor

Thanks for the patch, @MaximilianSchreff. Can you please add the missing license header to the Java test file?

@MaximilianSchreff
Contributor Author

@phaniarnab, sorry, I forgot that. It is added now.

@phaniarnab
Contributor

Thanks for the changes. I will merge it in. @MaximilianSchreff

@phaniarnab closed this in 85331dc on Jan 6, 2025