@@ -32,70 +32,6 @@ impl PyIterProtocol for PyEmbeddingIterator {
3232 let embeddings = slf. embeddings . borrow ( ) ;
3333 let vocab = embeddings. vocab ( ) ;
3434
35- if slf. idx < vocab. len ( ) {
36- let word = vocab. words ( ) [ slf. idx ] . to_string ( ) ;
37- let embed = embeddings. storage ( ) . embedding ( slf. idx ) ;
38-
39- slf. idx += 1 ;
40-
41- let gil = pyo3:: Python :: acquire_gil ( ) ;
42- Ok ( Some ( PyEmbedding {
43- word,
44- embedding : embed. into_owned ( ) . into_pyarray ( gil. python ( ) ) . to_owned ( ) ,
45- } ) )
46- } else {
47- Ok ( None )
48- }
49- }
50- }
51-
52- /// A word and its embedding.
53- #[ pyclass( name=Embedding ) ]
54- pub struct PyEmbedding {
55- embedding : Py < PyArray1 < f32 > > ,
56- word : String ,
57- }
58-
59- #[ pymethods]
60- impl PyEmbedding {
61- /// Get the embedding.
62- #[ getter]
63- pub fn get_embedding ( & self ) -> Py < PyArray1 < f32 > > {
64- let gil = Python :: acquire_gil ( ) ;
65- self . embedding . clone_ref ( gil. python ( ) )
66- }
67-
68- /// Get the word.
69- #[ getter]
70- pub fn get_word ( & self ) -> & str {
71- & self . word
72- }
73- }
74-
75- #[ pyclass( name=EmbeddingWithNormIterator ) ]
76- pub struct PyEmbeddingWithNormIterator {
77- embeddings : Rc < RefCell < EmbeddingsWrap > > ,
78- idx : usize ,
79- }
80-
81- impl PyEmbeddingWithNormIterator {
82- pub fn new ( embeddings : Rc < RefCell < EmbeddingsWrap > > , idx : usize ) -> Self {
83- PyEmbeddingWithNormIterator { embeddings, idx }
84- }
85- }
86-
87- #[ pyproto]
88- impl PyIterProtocol for PyEmbeddingWithNormIterator {
89- fn __iter__ ( slf : PyRefMut < Self > ) -> PyResult < Py < PyEmbeddingWithNormIterator > > {
90- Ok ( slf. into ( ) )
91- }
92-
93- fn __next__ ( mut slf : PyRefMut < Self > ) -> PyResult < Option < PyEmbeddingWithNorm > > {
94- let slf = & mut * slf;
95-
96- let embeddings = slf. embeddings . borrow ( ) ;
97- let vocab = embeddings. vocab ( ) ;
98-
9935 if slf. idx < vocab. len ( ) {
10036 let word = vocab. words ( ) [ slf. idx ] . to_string ( ) ;
10137 let embed = embeddings. storage ( ) . embedding ( slf. idx ) ;
@@ -104,7 +40,7 @@ impl PyIterProtocol for PyEmbeddingWithNormIterator {
10440 slf. idx += 1 ;
10541
10642 let gil = pyo3:: Python :: acquire_gil ( ) ;
107- Ok ( Some ( PyEmbeddingWithNorm {
43+ Ok ( Some ( PyEmbedding {
10844 word,
10945 embedding : embed. into_owned ( ) . into_pyarray ( gil. python ( ) ) . to_owned ( ) ,
11046 norm,
@@ -116,15 +52,15 @@ impl PyIterProtocol for PyEmbeddingWithNormIterator {
11652}
11753
11854/// A word and its embedding and embedding norm.
119- #[ pyclass( name=EmbeddingWithNorm ) ]
120- pub struct PyEmbeddingWithNorm {
55+ #[ pyclass( name=Embedding ) ]
56+ pub struct PyEmbedding {
12157 embedding : Py < PyArray1 < f32 > > ,
12258 norm : f32 ,
12359 word : String ,
12460}
12561
12662#[ pymethods]
127- impl PyEmbeddingWithNorm {
63+ impl PyEmbedding {
12864 /// Get the embedding.
12965 #[ getter]
13066 pub fn get_embedding ( & self ) -> Py < PyArray1 < f32 > > {
0 commit comments