Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions dev/todo.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
- [x] run inference, save to output.txt
- `./build/bin/llama-simple -m models/gemma/gemma-1.1-7b-it.Q4_K_M.gguf -n 100 -p "Tell me about the history of artificial intelligence" >> output.txt`
New way:
```
./build/bin/llama-run --ngl 999 models/gemma/gemma-1.1-7b-it.Q4_K_M.gguf Hello World > output.txt
```

- [x] b) I want to modify the code, re-build project and see the changes
- Just something stupid. Print hello world from Petr
- changed `simple.cpp`
``` fprintf(stderr, "Generating token number %d\n", n_decode + 1); ```
Runs fine.

- [x] c) Next, I want to hook specifically into the places where random numbers are generated.
- During inference, sampling
- Specifically, save each RNG-generated number to a file

- [x] d) then I want to replace all the custom non-trivial rng generation
- (e.g. "sample this custom distribution") with my own implementations using the basic uniform (0,1) rng generator

- [ ] e) then I want to replace the default (0,1) uniform distribution with my custom provider coming from external api

- [ ] f) Idea: measure the bias in the source distribution based on the numbers it generates
- specifically: see each generated number and see how it changes

- [ ] g) try / support temperature > 1
- try min_p lower (0?)
- try other models - bigger, better
58 changes: 34 additions & 24 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,35 +129,45 @@ struct ring_buffer {
};

// Picks one index into cur_p->data with probability proportional to each entry's `p`,
// using a single uniform draw from `rng` (inverse-CDF sampling). Every step is traced to
// stderr so the raw random number and the resulting choice can be inspected.
//
//   cur_p : candidate array; `size` and each entry's `id` and `p` are read, nothing is written
//           (assumes every `p` is >= 0 so the cumulative sums are non-decreasing - TODO confirm)
//   rng   : random engine used for the uniform draw
//
// Returns the selected index (not the token id). Returns 0 when the array is empty.
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
    // Get uniform random number in [0, 1)
    const double u = std::uniform_real_distribution<double>(0.0, 1.0)(rng);
    fprintf(stderr, "\nRNG internal:\n");
    fprintf(stderr, "- Raw uniform random number: %f\n", u);

    const size_t n = cur_p->size;
    if (n == 0) {
        // Nothing to index. std::discrete_distribution, which this code replaces, also
        // produces 0 for an empty weight range, so keep that result without touching data[].
        fprintf(stderr, "- No candidates, returning index 0\n");
        return 0;
    }

    // Running (unnormalized) cumulative sums, accumulated in double to limit rounding drift
    std::vector<double> cumulative_probs;
    cumulative_probs.reserve(n);
    double sum = 0.0;

    fprintf(stderr, "- Token probabilities:\n");
    for (size_t i = 0; i < n; ++i) {
        sum += cur_p->data[i].p;
        cumulative_probs.push_back(sum);
        fprintf(stderr, " [%zu] token %d = %f (cumulative: %f)\n",
                i, cur_p->data[i].id, cur_p->data[i].p, sum);
    }

    // Scale the draw onto [0, sum) instead of dividing every cumulative entry by sum:
    // same distribution, no second pass, and no division by zero when every p is 0
    const double scaled = u * sum;
    fprintf(stderr, "- Scaled random number: %f\n", scaled);

    // First entry whose cumulative sum is strictly greater than the draw. upper_bound rather
    // than lower_bound: a zero-probability entry has the same cumulative sum as its
    // predecessor and must never be selected, including when the draw is exactly 0.
    const auto it = std::upper_bound(cumulative_probs.begin(), cumulative_probs.end(), scaled);
    size_t selected_idx = static_cast<size_t>(it - cumulative_probs.begin());

    // Rounding can push the draw up to the final cumulative sum, and a zero or NaN sum matches
    // nothing, so the search may return end(). Clamp into range, then step back over
    // trailing zero-probability entries.
    if (selected_idx >= n) {
        selected_idx = n - 1;
        while (selected_idx > 0 && !(cur_p->data[selected_idx].p > 0.0f)) {
            --selected_idx;
        }
    }

    fprintf(stderr, "- Selected index: %zu\n", selected_idx);
    fprintf(stderr, "RNG generated sample: %zu (token id: %d, probability: %f)\n",
            selected_idx, cur_p->data[selected_idx].id, cur_p->data[selected_idx].p);

    return static_cast<int>(selected_idx);
}

/*
Expand Down