Skip to content

Commit 399f7e3

Browse files
authored
Merge pull request #83 from marty1885/master
Fix! More fixes!!
2 parents 1db8ea7 + ce93e5f commit 399f7e3

File tree

7 files changed

+22
-16
lines changed

7 files changed

+22
-16
lines changed

Etaler/Algorithms/TemporalMemory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ void TemporalMemory::learn(const Tensor& active_cells, const Tensor& last_active
3131
Tensor learning_cells = reverseBurst(active_cells);
3232

3333
learnCorrilation(last_active, learning_cells, connections_, permanences_, permanence_inc_, permanence_dec_);
34-
growSynapses(last_active, learning_cells, connections_, permanences_, 0.21);
34+
growSynapses(last_active, learning_cells, connections_, permanences_, initial_permanence_);
3535

3636
}
3737

Etaler/Algorithms/TemporalMemory.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ struct ETALER_EXPORT TemporalMemory
3131
void setActiveThreshold(size_t thr) { active_threshold_ = thr; }
3232
size_t activeThreshold() const { return active_threshold_; }
3333

34-
size_t cellsPerColumn() const {return connections_.shape().back();}
34+
size_t cellsPerColumn() const {return connections_.shape()[connections_.size()-2];}
35+
size_t maxSynapsesPerCell() const {return connections_.shape().back();}
36+
37+
float initialPermanence() const {return initial_permanence_;}
38+
void setInitialPermanence(float p) {initial_permanence_ = p;}
3539

3640
Tensor connections() const {return connections_;}
3741
Tensor permanences() const {return permanences_;}
@@ -53,10 +57,11 @@ struct ETALER_EXPORT TemporalMemory
5357
void loadState(const StateDict& states);
5458

5559
Shape input_shape_;
56-
float connected_permanence_ = 0.1;
60+
float connected_permanence_ = 0.15;
5761
size_t active_threshold_ = 2;
5862
float permanence_inc_ = 0.1;
5963
float permanence_dec_ = 0.1;
64+
float initial_permanence_ = 0.21;
6065
Tensor connections_;
6166
Tensor permanences_;
6267
};

Etaler/Backends/CPUBackend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ void learnCorrilation(const TensorImpl* x, const TensorImpl* learn, const Tensor
164164
else
165165
perm -= perm_dec;
166166

167-
perm = std::max(std::min(perm, PermType(1)), PermType(0));
167+
perm = std::clamp(perm, PermType(0), PermType(1));
168168
}
169169
});
170170
}

Etaler/Backends/OpenCLBackend.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ static std::string readAll(const std::string& path)
2424
return str;
2525
}
2626

27-
inline size_t selectWorkSize(size_t max, size_t mul_of, size_t size)
27+
inline intmax_t selectWorkSize(intmax_t max, intmax_t mul_of, intmax_t size)
2828
{
2929
auto round = [mul_of](auto v){return ((v/mul_of)*mul_of) + (v%mul_of == 0 ? 0 : mul_of);};
30-
return std::min((size_t)max, round(size));
30+
return std::min((intmax_t)max, round(size));
3131
}
3232

3333

@@ -488,8 +488,8 @@ std::shared_ptr<TensorImpl> OpenCLBackend::reverseBurst(const TensorImpl* x)
488488

489489
auto res = copy(x);
490490

491-
size_t local_size = 128;
492-
size_t global_size = selectWorkSize(4096, local_size, num_columns);
491+
intmax_t local_size = 128;
492+
intmax_t global_size = selectWorkSize(4096, local_size, num_columns);
493493
std::vector<uint32_t> seed1(global_size);
494494
std::vector<uint32_t> seed2(global_size);
495495

Etaler/Core/Tensor.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ static size_t prettyPrintTensor(std::ostream& os, const T* arr, Shape shape, siz
4040
os << str << std::string(padding_len, ' ') << (i==size-1 ? "" : ", ");
4141
}
4242
}
43-
//Print the truncated version
43+
//Print the truncated version. ex: {1, 1, 1, ... 1, 1, 1}
4444
else {
4545
//The first half
4646
for(intmax_t i=0;i<max_line_content/2;i++) {
@@ -164,7 +164,8 @@ bool Tensor::isSame(const Tensor& other) const
164164
if(shape() != other.shape())
165165
return false;
166166

167-
return (*this == other).sum().toHost<int32_t>()[0] == (int32_t)size();
167+
//A hacky comparsion
168+
return (*this == other).sum().item<int32_t>() == (int32_t)size();
168169
}
169170

170171
Tensor Tensor::view(svector<Range> ranges) const

Etaler/Core/Tensor.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,22 +210,22 @@ struct ETALER_EXPORT Tensor
210210
std::shared_ptr<TensorImpl> pimpl_;
211211
};
212212

213-
static Tensor operator+ (std::variant<float, int, bool> v, const Tensor& t)
213+
static Tensor operator+ (std::variant<float, int, bool, half> v, const Tensor& t)
214214
{
215215
return std::visit([&t](auto v) {return Tensor(v)+t;}, v);
216216
}
217217

218-
static Tensor operator- (std::variant<float, int, bool> v, const Tensor& t)
218+
static Tensor operator- (std::variant<float, int, bool, half> v, const Tensor& t)
219219
{
220220
return std::visit([&t](auto v) {return Tensor(v)-t;}, v);
221221
}
222222

223-
static Tensor operator* (std::variant<float, int, bool> v, const Tensor& t)
223+
static Tensor operator* (std::variant<float, int, bool, half> v, const Tensor& t)
224224
{
225225
return std::visit([&t](auto v) {return Tensor(v)*t;}, v);
226226
}
227227

228-
static Tensor operator/ (std::variant<float, int, bool> v, const Tensor& t)
228+
static Tensor operator/ (std::variant<float, int, bool, half> v, const Tensor& t)
229229
{
230230
return std::visit([&t](auto v) {return Tensor(v)/t;}, v);
231231
}

Etaler/Interop/Xtensor.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ xt::xarray<T> to_xarray(const Tensor& t)
2121
auto shape = t.shape();
2222
std::vector<size_t> s(shape.begin(), shape.end());
2323
//Handle the case of bool
24-
using DataType = std::conditional<std::is_same<T, bool>, uint8_t, T>::type;
24+
using DataType = typename std::conditional<std::is_same_v<T, bool>, uint8_t, T>::type;
2525
auto vec = t.toHost<DataType>();
2626
return xt::adapt((const T*)vec.data(), s);
2727
}
2828

29-
}
29+
}

0 commit comments

Comments
 (0)