diff --git a/LanguageModel.lua b/LanguageModel.lua
index d6248184..70b2f9d1 100644
--- a/LanguageModel.lua
+++ b/LanguageModel.lua
@@ -132,16 +132,8 @@ function LM:encode_string(s)
   return encoded
 end
 
-
-function LM:decode_string(encoded)
-  assert(torch.isTensor(encoded) and encoded:dim() == 1)
-  local s = ''
-  for i = 1, encoded:size(1) do
-    local idx = encoded[i]
-    local token = self.idx_to_token[idx]
-    s = s .. token
-  end
-  return s
+function LM:decode_char(encoded)
+  return self.idx_to_token[encoded[1][1]]
 end
 
 
@@ -162,8 +154,8 @@ function LM:sample(kwargs)
   local verbose = utils.get_kwarg(kwargs, 'verbose', 0)
   local sample = utils.get_kwarg(kwargs, 'sample', 1)
   local temperature = utils.get_kwarg(kwargs, 'temperature', 1)
+  local write = io.write
 
-  local sampled = torch.LongTensor(1, T)
   self:resetStates()
 
   local scores, first_t
@@ -173,7 +165,7 @@
     end
     local x = self:encode_string(start_text):view(1, -1)
     local T0 = x:size(2)
-    sampled[{{}, {1, T0}}]:copy(x)
+    write(start_text)
     scores = self:forward(x)[{{}, {T0, T0}}]
     first_t = T0 + 1
   else
@@ -195,12 +187,12 @@
       probs:div(torch.sum(probs))
       next_char = torch.multinomial(probs, 1):view(1, 1)
     end
-    sampled[{{}, {t, t}}]:copy(next_char)
+    write(self:decode_char(next_char))
     scores = self:forward(next_char)
   end
 
   self:resetStates()
-  return self:decode_string(sampled[1])
+  write('\n')
 end
 
 
diff --git a/sample.lua b/sample.lua
index 4e6ebae0..2e7548a0 100644
--- a/sample.lua
+++ b/sample.lua
@@ -38,5 +38,4 @@ if opt.verbose == 1 then print(msg) end
 
 model:evaluate()
 
-local sample = model:sample(opt)
-print(sample)
+model:sample(opt)