Skip to content

Commit f63a706

Browse files
committed
compile succeed
1 parent b805ed0 commit f63a706

File tree

7 files changed

+17
-21
lines changed

7 files changed

+17
-21
lines changed

cpp/include/cuw2v/cuda_w2v_base_kernels.cuh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
#pragma once
77
#include "utils/cuda_utils_kernels.cuh"
88

9-
using thrust::random::default_random_engine;
10-
using thrust::random::uniform_int_distribution;
11-
129
namespace cusim {
1310

1411

cpp/include/cuw2v/cuw2v.hpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,6 @@ using thrust::random::default_random_engine;
3434

3535
namespace cusim {
3636

37-
bool CompareIndex(int lhs, int rhs);
38-
39-
struct HuffmanTreeNode {
40-
float count;
41-
int index, left, right;
42-
HuffmanTreeNode(float count0, int index0, int left0, int right0) {
43-
count = count0; index = index0; left = left0; right = right0;
44-
}
45-
};
46-
47-
std::vector<HuffmanTreeNode> huffman_nodes;
48-
bool CompareIndex(int lhs, int rhs);
49-
5037
class CuW2V {
5138
public:
5239
CuW2V();

cpp/include/utils/cuda_utils_kernels.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
#include <utility>
2424
#include "utils/types.hpp"
2525

26+
using thrust::random::default_random_engine;
27+
using thrust::random::uniform_int_distribution;
28+
2629
namespace cusim {
2730

2831
// Error Checking utilities, checks status codes from cuda calls

cpp/src/cuw2v/cuw2v.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@
1010

1111
namespace cusim {
1212

13+
struct HuffmanTreeNode {
14+
float count;
15+
int index, left, right;
16+
HuffmanTreeNode(float count0, int index0, int left0, int right0) {
17+
count = count0; index = index0; left = left0; right = right0;
18+
}
19+
};
20+
21+
std::vector<HuffmanTreeNode> huffman_nodes;
1322
bool CompareIndex(int lhs, int rhs) {
1423
return huffman_nodes[lhs].count > huffman_nodes[rhs].count;
1524
}

cpp/src/utils/ioutils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ void IoUtils::GetWordVocab(int min_count, std::string keys_path, std::string cou
175175
line = std::to_string(word_count_[word_list_[i]]) + "\n";
176176
fout2.write(line.c_str(), line.size());
177177
}
178-
fout.close();
178+
fout1.close(); fout2.close();
179179
}
180180

181181
} // namespace cusim

cusim/cuw2v/bindings.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class CuW2VBind {
2929
float_array _emb_out(emb_out);
3030
auto emb_in_buffer = _emb_in.request();
3131
auto emb_out_buffer = _emb_out.request();
32-
if (emb_in_buffer.ndim != 2 or emb_out_buffer_buffer.ndim != 2 or
32+
if (emb_in_buffer.ndim != 2 or emb_out_buffer.ndim != 2 or
3333
emb_in_buffer.shape[1] != emb_out_buffer.shape[1]) {
3434
throw std::runtime_error("invalid emb_in or emb_out");
3535
}
@@ -39,7 +39,7 @@ class CuW2VBind {
3939

4040
void BuildRandomTable(py::object& word_count, int table_size, int num_threads) {
4141
float_array _word_count(word_count);
42-
auto wc_buffer = _word_count.requiest();
42+
auto wc_buffer = _word_count.request();
4343
if (wc_buffer.ndim != 1) {
4444
throw std::runtime_error("invalid word count");
4545
}
@@ -49,7 +49,7 @@ class CuW2VBind {
4949

5050
void BuildHuffmanTree(py::object& word_count) {
5151
float_array _word_count(word_count);
52-
auto wc_buffer = _word_count.requiest();
52+
auto wc_buffer = _word_count.request();
5353
if (wc_buffer.ndim != 1) {
5454
throw std::runtime_error("invalid word count");
5555
}

cusim/ioutils/bindings.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ PYBIND11_PLUGIN(ioutils_bind) {
6262
.def("tokenize_stream", &IoUtilsBind::TokenizeStream,
6363
py::arg("num_lines"), py::arg("num_threads"))
6464
.def("get_word_vocab", &IoUtilsBind::GetWordVocab,
65-
py::arg("min_count"), py::arg("keys_path"), py::Arg("count_path"))
65+
py::arg("min_count"), py::arg("keys_path"), py::arg("count_path"))
6666
.def("get_token", &IoUtilsBind::GetToken,
6767
py::arg("indices"), py::arg("indptr"), py::arg("offset"))
6868
.def("__repr__",

0 commit comments

Comments
 (0)