TRF model Doc vector on one word sentence totally different from the word vector itself #9536
Answered by adrianeboyd

oliviercwa asked this question in Help: Other Questions
How to reproduce the behaviour

It might not be a bug, but I find the result very surprising:

```python
import spacy
import numpy as np
from thinc.util import get_array_module

def norm(vector) -> float:
    xp = get_array_module(vector)
    total = (vector**2).sum()
    return xp.sqrt(total) if total != 0. else 0.

# Define a one-word sentence
nlp_trf = spacy.load('en_core_web_trf')
doc = nlp_trf('VESSEL')

# Get doc vector
doc_vect = doc._.trf_data.tensors[-1].mean(axis=0)

# Get span vector for the full doc
span = doc[:]
tensor_ix = span.doc._.trf_data.align[span.start: span.end].data.flatten()
out_dim = span.doc._.trf_data.tensors[0].shape[-1]
tensor = span.doc._.trf_data.tensors[0].reshape(-1, out_dim)[tensor_ix]
span_vect = tensor.mean(axis=0)

# Print similarity. Expected to be close to 1; instead the two are totally dissimilar
print(np.dot(doc_vect, span_vect) / (norm(doc_vect) * norm(span_vect)))
# -0.013732557
```
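The value printed above is plain cosine similarity; the `get_array_module` helper only exists so the same `norm` works for both NumPy and CuPy arrays. A quick self-contained check of the formula (NumPy only, with made-up vectors):

```python
import numpy as np

def norm(vector) -> float:
    # Same formula as the helper above, NumPy-only for illustration
    total = (vector ** 2).sum()
    return float(np.sqrt(total)) if total != 0. else 0.

a = np.array([1.0, 0.0])
b = np.array([0.0, 1.0])
c = np.array([2.0, 0.0])

# Orthogonal vectors score 0, parallel vectors score 1
print(np.dot(a, b) / (norm(a) * norm(b)))  # 0.0
print(np.dot(a, c) / (norm(a) * norm(c)))  # 1.0
```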
Answered by adrianeboyd on Oct 25, 2021

Replies: 1 comment, 1 reply
Hi, the difference is whether you're including the special tokens or not. If you treat the doc the same way as the span, you get the same results:

```python
doc_vect = doc._.trf_data.tensors[-1].mean(axis=0)
tensor_ix = doc._.trf_data.align[0: len(doc)].data.flatten()
out_dim = doc._.trf_data.tensors[0].shape[-1]
tensor = doc._.trf_data.tensors[0].reshape(-1, out_dim)[tensor_ix]
doc_vect = tensor.mean(axis=0)
```
Answer selected by oliviercwa
There are five tokens on the transformer side (`['<s>', 'V', 'ESS', 'EL', '</s>']`) and the alignment to "VESSEL" in `trf_data.align` does not include the `<s>` and `</s>` tokens.
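The reshape-and-index pattern used in both snippets above can be sketched with plain NumPy: flatten the (batch, seq, dim) wordpiece tensor to (seq, dim), then select only the rows the alignment points at, which skips row 0 (`<s>`) and the last row (`</s>`). The shapes and alignment indices below are illustrative stand-ins, not values read from the model:

```python
import numpy as np

# Stand-in for trf_data.tensors[0]: 1 sequence of 5 wordpieces
# (['<s>', 'V', 'ESS', 'EL', '</s>']) with an illustrative hidden size of 4
tensors0 = np.arange(1 * 5 * 4, dtype=float).reshape(1, 5, 4)

# Stand-in for the flattened alignment of the token "VESSEL":
# it covers wordpieces 1..3 and excludes the special tokens at 0 and 4
tensor_ix = np.array([1, 2, 3])

out_dim = tensors0.shape[-1]
tensor = tensors0.reshape(-1, out_dim)[tensor_ix]
print(tensor.shape)  # (3, 4): only the three aligned wordpiece rows
```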