Suggest: Add Bayesian optimization support for ratio search #104

Open
trotsky1997 wants to merge 2 commits into mit-han-lab:main from trotsky1997:zhangdi
Conversation

@trotsky1997

No description provided.

trotsky1997@qq.com added 2 commits October 26, 2023 20:40
@casper-hansen
Contributor

Hi @trotsky1997, this looks very interesting! Have you conducted any experiments to measure perplexity after using Bayesian optimization?

@trotsky1997
Author

> Hi @trotsky1997, this looks very interesting! Have you conducted any experiments to measure perplexity after using Bayesian optimization?

You can check my results at
https://trotsky1997.notion.site/f49dcb79ab6245a7b689beed086e4c7b?pvs=4

@casper-hansen
Contributor

@trotsky1997 does this code include different alpha values for X and W? You observed better perplexity with that.

@trotsky1997
Author

> @trotsky1997 does this code include different alpha values for X and W? You observed better perplexity with that.

That's very easy to modify: add a new parameter called ratio_b to the get_loss function, replace 1 - ratio with ratio_b, then define ratio_b with its bounds in the parameter-space definition.

@trotsky1997
Author

```python
@scheduler.serial
def get_loss(ratio, ratio_b):
    nonlocal best_error, best_ratio, best_scales
    # ratio and ratio_b are sampled directly from uniform(0, 1), so the old
    # grid-search rescaling (ratio = ratio * 1 / n_grid) is no longer needed.
    scales = (x_max.pow(ratio) / w_max.pow(ratio_b)).clamp(min=1e-4).view(-1)
    scales = scales / (scales.max() * scales.min()).sqrt()
    for fc in linears2scale:
        fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
        fc.weight.data = w_quantize_func(fc.weight.data) / scales.view(1, -1)
    out = block(x, **kwargs)
    if isinstance(out, tuple):
        out = out[0]

    loss = (org_out - out).float().pow(2).mean().item()  # float prevents overflow
    history.append(loss)
    if loss < best_error:
        best_error = loss
        best_ratio = (ratio, ratio_b)  # track both exponents, not just ratio
        best_scales = scales
    block.load_state_dict(org_sd)  # restore original weights before the next trial
    return loss

param_space = dict(ratio=uniform(0, 1), ratio_b=uniform(0, 1))
```
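The objective above is designed to be handed to a Bayesian optimizer over `param_space` (the `@scheduler.serial` decorator and the `uniform(0, 1)` space suggest the ARM-mango tuner). Here is a minimal self-contained sketch of that interface, with plain random search standing in for the Bayesian sampler and a toy quadratic standing in for the real block-reconstruction loss (`get_loss`, `x_max`, `w_max`, etc. come from the surrounding AWQ code and are not reproduced here):

```python
import random

random.seed(0)

# Toy stand-in for get_loss: the real objective measures the block's output MSE
# after scaling weights by x_max**ratio / w_max**ratio_b. The optimum at
# (0.5, 0.3) is purely illustrative.
def get_loss(ratio, ratio_b):
    return (ratio - 0.5) ** 2 + (ratio_b - 0.3) ** 2

# Search over the same space as param_space = dict(ratio=uniform(0, 1),
# ratio_b=uniform(0, 1)); a Bayesian tuner would replace this sampling loop
# with model-guided proposals.
best_error, best_params = float("inf"), None
for _ in range(200):
    params = dict(ratio=random.random(), ratio_b=random.random())
    loss = get_loss(**params)
    if loss < best_error:
        best_error, best_params = loss, params

print(best_error, best_params)
```

The point of the sketch is only the contract: the objective takes the sampled parameters by name, returns a scalar loss, and the caller keeps the best trial, exactly as the PR's `get_loss` does with `best_error`/`best_scales`.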

@trotsky1997
Author

> @trotsky1997 does this code include different alpha values for X and W? You observed better perplexity with that.

I have talked with Dr. Tang; it performs a little better than grid search on Vicuna, but about the same as grid search on LLaMA-2-7B.
