
Conversation

@kaixuanliu (Contributor) commented on Oct 17, 2025

When we run a unit test like `pytest -rA tests/pipelines/wan/test_wan_22.py::Wan22PipelineFastTests::test_save_load_float16`, the pipeline `pipe` runs with all-fp16 dtypes, but after saving and reloading, some parts of the text encoder in `pipe_loaded` use fp32, even though `torch_dtype` is explicitly set to fp16. Deeper investigation found that the root cause is here: L783. This PR adjusts the test case to manually add the `component = component.to(torch_device).half()` operation so that `pipe_loaded` aligns exactly with the behavior of `pipe`. A rough sketch of the adjusted flow is below.
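
The following is only a minimal sketch of the idea, not the exact diff in this PR; names such as `self.get_dummy_components()`, `self.pipeline_class`, and `torch_device` follow the usual diffusers `PipelineTesterMixin` test conventions and are assumed here:

```python
import tempfile

import torch
from diffusers.utils.testing_utils import torch_device

# `pipe`: every component that supports it is explicitly cast to fp16 before inference.
components = self.get_dummy_components()
for name, module in components.items():
    if hasattr(module, "half"):
        components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)

with tempfile.TemporaryDirectory() as tmpdir:
    pipe.save_pretrained(tmpdir)
    pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
    pipe_loaded.to(torch_device)
    # Proposed adjustment: cast the reloaded components back to fp16 as well, since
    # from_pretrained keeps some text-encoder submodules in fp32 by default.
    for name, component in pipe_loaded.components.items():
        if hasattr(component, "half"):
            setattr(pipe_loaded, name, component.to(torch_device).half())
```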

@kaixuanliu (Contributor, Author) commented
@a-r-r-o-w @DN6 Please help review, thanks!

@regisss (Contributor) commented on Oct 22, 2025

Not sure I understand the issue here. This specific T5 module is kept in fp32 on purpose, so why force an fp16 cast in the test?

@kaixuanliu (Contributor, Author) commented on Oct 23, 2025

@regisss Hi, the purpose of this test case is to compare the output of the fp16 pipeline (`pipe`) against the output of the pipeline reloaded from the previously saved checkpoint (`pipe_loaded`); the two should match. However, all components of `pipe` are cast to fp16 in L1424~L1426, while in `pipe_loaded` some parts are kept in fp32, so its computation does not exactly match the forward pass of `pipe`.
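
A quick way to see the mismatch (a sketch assuming a `pipe_loaded` created with `torch_dtype=torch.float16` as in the test; the text-encoder submodules that are deliberately kept in fp32 show up here):

```python
import torch

# Parameters that from_pretrained left in fp32, even though
# torch_dtype=torch.float16 was passed.
fp32_params = [
    name
    for name, p in pipe_loaded.text_encoder.named_parameters()
    if p.dtype == torch.float32
]
print(fp32_params)  # non-empty for pipe_loaded, while pipe has none after the explicit .half() cast
```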
