From 8ff9bfe02c8d0e4ce33da6e0134dc2b41531d29c Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Sat, 5 Oct 2024 16:45:52 +0900 Subject: [PATCH 01/29] Fix some documentation in ./src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 128 ++++++++++++++++---------------- 1 file changed, 63 insertions(+), 65 deletions(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 0f4b2ec03371..4102269c82fe 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -30,11 +30,11 @@ class MultiAdapter(ModelMixin): MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to user-assigned weighting. - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as + downloading or saving. - Parameters: - adapters (`List[T2IAdapter]`, *optional*, defaults to None): + Args: + adapters (`List[T2IAdapter]`, optional, defaults=None): A list of `T2IAdapter` model instances. """ @@ -77,11 +77,14 @@ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = Non r""" Args: xs (`torch.Tensor`): - (batch, channel, height, width) input images for multiple adapter models concated along dimension 1, - `channel` should equal to `num_adapter` * "number of channel of image". - adapter_weights (`List[float]`, *optional*, defaults to None): - List of floats representing the weight which will be multiply to each adapter's output before adding + A tensor of shape (batch, channel, height, width) representing input images for multiple adapter models, + concatenated along dimension 1(channel dimension). + The `channel` dimension should be equal to `num_adapter` * number of channel per image. + + adapter_weights (`List[float]`, optional, defaults=None): + A list of floats representing the weight which will be multiplied to each adapter's output before summing them together. + If `None`, equal weights will be used for all adapters. """ if adapter_weights is None: adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter) @@ -109,24 +112,24 @@ def save_pretrained( variant: Optional[str] = None, ): """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the + Save a model and its configuration file to a specified directory, allowing it to be re-loaded with the `[`~models.adapter.MultiAdapter.from_pretrained`]` class method. - Arguments: + Args: save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. + The directory where the model will be saved. If the directory does not exist, it will be created. + is_main_process (`bool`, optional, defaults=True): + Indicates whether current process is the main process or not. + Useful when in distributed training (e.g., TPUs) and need to call this function on all processes. + In this case, set `is_main_process=True` only for the main process to avoid race conditions. 
save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). - variant (`str`, *optional*): - If specified, weights are saved in the format pytorch_model..bin. + Function used to save the state dictionary. Useful on distributed training (e.g., TPUs) when one + need to replace `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment variable. + safe_serialization (`bool`, optional, defaults=True): + If `True`, save the model using `safetensors`. + If `False`, save the model using the traditional PyTorch way (using `pickle`). + variant (`str`, optional): + If specified, weights are saved in the format `pytorch_model..bin`. """ idx = 0 model_path_to_save = save_directory @@ -145,28 +148,25 @@ def save_pretrained( @classmethod def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): r""" - Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models. + Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train - the model, you should first set it back in training mode with `model.train()`. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. + the model, set it back to training mode using `model.train()`. - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. + Warnings: + *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. + *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded. - Parameters: + Args: pretrained_model_path (`os.PathLike`): A path to a *directory* containing model weights saved using [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`. - torch_dtype (`str` or `torch.dtype`, *optional*): + torch_dtype (`str` or `torch.dtype`, optional): Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype will be automatically derived from the model's weights. - output_loading_info(`bool`, *optional*, defaults to `False`): + output_loading_info(`bool`, optional, defaults=False): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, optional): A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the same device. 
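             For reference, a minimal loading sketch exercising the `device_map` option described above; the
             directory path is a placeholder from a prior `save_pretrained` call, and `device_map="auto"`
             assumes Accelerate is installed:

             ```py
             from diffusers import MultiAdapter

             # Let Accelerate place each submodule on the best available device.
             multi_adapter = MultiAdapter.from_pretrained(
                 "./my_model_directory/adapter",  # placeholder path
                 device_map="auto",
             )
             ```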
@@ -174,21 +174,21 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - max_memory (`Dict`, *optional*): - A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + max_memory (`Dict`, optional): + A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory available for each GPU and the available CPU RAM if unset. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + low_cpu_mem_usage (`bool`, optional, defaults=True if torch version >= 1.9.0 else `False`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, setting this argument to `True` will raise an error. - variant (`str`, *optional*): - If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + variant (`str`, optional): + If specified, load weights from a `variant` file (*e.g.* pytorch_model..bin). `variant` will be ignored when using `from_flax`. - use_safetensors (`bool`, *optional*, defaults to `None`): - If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the - `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from - `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. + use_safetensors (`bool`, optional, defaults=None): + If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is installed. + If `True`, the model will be forcibly loaded from`safetensors` weights. + If `False`, `safetensors` is not used. """ idx = 0 adapters = [] @@ -223,22 +223,20 @@ class T2IAdapter(ModelMixin, ConfigMixin): and [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235). - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - in_channels (`int`, *optional*, defaults to 3): - Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale - image as *control image*. - channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will - also determine the number of downsample blocks in the Adapter. - num_res_blocks (`int`, *optional*, defaults to 2): + This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as downloading or saving. + + Args: + in_channels (`int`, optional, defaults=3): + Number of channels of Aapter's input(*control image*). Set to 1 if you're using gray scale image. + channels (`List[int]`, optional, defaults=[320, 640, 1280, 1280]): + Number of channel of each downsample block's output hidden state. 
The `len(block_out_channels)` determines + the number of downsample blocks in the Adapter. + num_res_blocks (`int`, optional, defaults=2): Number of ResNet blocks in each downsample block. - downscale_factor (`int`, *optional*, defaults to 8): + downscale_factor (`int`, optional, defaults=8): A factor that determines the total downscale factor of the Adapter. - adapter_type (`str`, *optional*, defaults to `full_adapter`): - The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`. + adapter_type (`str`, optional, defaults=`full_adapter`): + Type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`. """ @register_to_config @@ -393,15 +391,15 @@ class AdapterBlock(nn.Module): An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and `FullAdapterXL` models. - Parameters: + Args: in_channels (`int`): Number of channels of AdapterBlock's input. out_channels (`int`): Number of channels of AdapterBlock's output. num_res_blocks (`int`): Number of ResNet blocks in the AdapterBlock. - down (`bool`, *optional*, defaults to `False`): - Whether to perform downsampling on AdapterBlock's input. + down (`bool`, optional, defaults=False): + If `True`, perform downsampling on AdapterBlock's input. """ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False): @@ -440,7 +438,7 @@ class AdapterResnetBlock(nn.Module): r""" An `AdapterResnetBlock` is a helper model that implements a ResNet-like block. - Parameters: + Args: channels (`int`): Number of channels of AdapterResnetBlock's input and output. """ @@ -518,15 +516,15 @@ class LightAdapterBlock(nn.Module): A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the `LightAdapter` model. - Parameters: + Args: in_channels (`int`): Number of channels of LightAdapterBlock's input. out_channels (`int`): Number of channels of LightAdapterBlock's output. num_res_blocks (`int`): Number of LightAdapterResnetBlocks in the LightAdapterBlock. - down (`bool`, *optional*, defaults to `False`): - Whether to perform downsampling on LightAdapterBlock's input. + down (`bool`, optional, defaults=False): + If `True`, perform downsampling on LightAdapterBlock's input. """ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False): @@ -561,7 +559,7 @@ class LightAdapterResnetBlock(nn.Module): A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different architecture than `AdapterResnetBlock`. - Parameters: + Args: channels (`int`): Number of channels of LightAdapterResnetBlock's input and output. """ From 3aebab593534f02a74ec0106745419f7c59d5c1d Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:21:36 +0900 Subject: [PATCH 02/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 4102269c82fe..bd1b3815df04 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -34,7 +34,7 @@ class MultiAdapter(ModelMixin): downloading or saving. Args: - adapters (`List[T2IAdapter]`, optional, defaults=None): + adapters (`List[T2IAdapter]`, *optional*, defaults to None): A list of `T2IAdapter` model instances. 
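    Example (a minimal sketch: the two Hub checkpoint names below are illustrative `T2IAdapter` weights, and
    the control inputs are stand-in random tensors):

    ```py
    import torch
    from diffusers import MultiAdapter, T2IAdapter

    # Combine two adapters, e.g. one conditioned on depth maps and one on Canny edges.
    adapters = MultiAdapter(
        [
            T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd15v2"),
            T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd15v2"),
        ]
    )

    # Control inputs are concatenated along the channel dimension, so two RGB
    # control images give a tensor of shape (batch, 2 * 3, height, width).
    depth, canny = torch.randn(1, 3, 512, 512), torch.randn(1, 3, 512, 512)
    features = adapters(torch.cat([depth, canny], dim=1), adapter_weights=[0.6, 0.4])
    ```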
""" From 42754341bb385419acc70a009f18be3f3c643805 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:21:42 +0900 Subject: [PATCH 03/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index bd1b3815df04..3cc36107f2b0 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -174,7 +174,7 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - max_memory (`Dict`, optional): + max_memory (`Dict`, *optional*): A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory available for each GPU and the available CPU RAM if unset. low_cpu_mem_usage (`bool`, optional, defaults=True if torch version >= 1.9.0 else `False`): From 5b820ee5e39165f4985af88ab47febfd339da00a Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:21:47 +0900 Subject: [PATCH 04/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 3cc36107f2b0..bb1a29714e8f 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -81,7 +81,7 @@ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = Non concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to `num_adapter` * number of channel per image. - adapter_weights (`List[float]`, optional, defaults=None): + adapter_weights (`List[float]`, *optional*, defaults to None): A list of floats representing the weight which will be multiplied to each adapter's output before summing them together. If `None`, equal weights will be used for all adapters. From 7868ba556b7e243484ce6244e1aafa4b322b482d Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:21:52 +0900 Subject: [PATCH 05/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index bb1a29714e8f..02682484c97f 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -161,7 +161,7 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] pretrained_model_path (`os.PathLike`): A path to a *directory* containing model weights saved using [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`. - torch_dtype (`str` or `torch.dtype`, optional): + torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype will be automatically derived from the model's weights. 
output_loading_info(`bool`, optional, defaults=False): From 31b2ed9ef80f59b77f3df87ca867e0d1ebe7b5b1 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:21:59 +0900 Subject: [PATCH 06/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 02682484c97f..e32acc6ddc36 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -231,7 +231,7 @@ class T2IAdapter(ModelMixin, ConfigMixin): channels (`List[int]`, optional, defaults=[320, 640, 1280, 1280]): Number of channel of each downsample block's output hidden state. The `len(block_out_channels)` determines the number of downsample blocks in the Adapter. - num_res_blocks (`int`, optional, defaults=2): + num_res_blocks (`int`, *optional*, defaults to `2`): Number of ResNet blocks in each downsample block. downscale_factor (`int`, optional, defaults=8): A factor that determines the total downscale factor of the Adapter. From 73cee8ff372ee948a41b2457e34635fe2588a9e5 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:22:05 +0900 Subject: [PATCH 07/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index e32acc6ddc36..54f59972b574 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -166,7 +166,7 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] will be automatically derived from the model's weights. output_loading_info(`bool`, optional, defaults=False): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, optional): + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the same device. From 4f9948df86724f66c8da0347483055775e12c47c Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:22:10 +0900 Subject: [PATCH 08/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 54f59972b574..8222d92bdea9 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -235,7 +235,7 @@ class T2IAdapter(ModelMixin, ConfigMixin): Number of ResNet blocks in each downsample block. downscale_factor (`int`, optional, defaults=8): A factor that determines the total downscale factor of the Adapter. - adapter_type (`str`, optional, defaults=`full_adapter`): + adapter_type (`str`, *optional*, defaults to `full_adapter`): Type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`. 
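    Example (a minimal sketch that simply spells out the defaults documented above):

    ```py
    import torch
    from diffusers import T2IAdapter

    adapter = T2IAdapter(
        in_channels=3,                    # RGB control image; set to 1 for grayscale
        channels=[320, 640, 1280, 1280],  # four downsample blocks
        num_res_blocks=2,
        downscale_factor=8,
        adapter_type="full_adapter",
    )

    # With downscale_factor=8, a 512x512 control image produces feature maps
    # starting at 64x64.
    features = adapter(torch.randn(1, 3, 512, 512))
    ```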
""" From b5a4526b2fd1b80cb6229fa9ec6c9962696070aa Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:22:16 +0900 Subject: [PATCH 09/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 8222d92bdea9..8dc07456e1d1 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -233,7 +233,7 @@ class T2IAdapter(ModelMixin, ConfigMixin): the number of downsample blocks in the Adapter. num_res_blocks (`int`, *optional*, defaults to `2`): Number of ResNet blocks in each downsample block. - downscale_factor (`int`, optional, defaults=8): + downscale_factor (`int`, *optional*, defaults to `8`): A factor that determines the total downscale factor of the Adapter. adapter_type (`str`, *optional*, defaults to `full_adapter`): Type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`. From b006b7809d65780d2bf0bf48dc7f6c73de1d2204 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:22:26 +0900 Subject: [PATCH 10/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 8dc07456e1d1..9eb75f496132 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -398,7 +398,7 @@ class AdapterBlock(nn.Module): Number of channels of AdapterBlock's output. num_res_blocks (`int`): Number of ResNet blocks in the AdapterBlock. - down (`bool`, optional, defaults=False): + down (`bool`, *optional*, defaults to `False`): If `True`, perform downsampling on AdapterBlock's input. """ From eceddbfbc46d3b20b258e8c3b50a22abe022e752 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:22:31 +0900 Subject: [PATCH 11/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 9eb75f496132..08d4e3b8a7cd 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -523,7 +523,7 @@ class LightAdapterBlock(nn.Module): Number of channels of LightAdapterBlock's output. num_res_blocks (`int`): Number of LightAdapterResnetBlocks in the LightAdapterBlock. - down (`bool`, optional, defaults=False): + down (`bool`, *optional*, defaults to `False`): If `True`, perform downsampling on LightAdapterBlock's input. """ From f0fd9756f71bb0f84a10ede06b60e7cbbabeab51 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:22:40 +0900 Subject: [PATCH 12/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 08d4e3b8a7cd..8cbb904349e8 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -164,7 +164,7 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype will be automatically derived from the model's weights. 
- output_loading_info(`bool`, optional, defaults=False): + output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be refined to each From bafb35ab8d5f6b2b6baacbf5a68cf87e80cdac60 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:22:48 +0900 Subject: [PATCH 13/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 8cbb904349e8..a9fd34ce719c 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -128,7 +128,7 @@ def save_pretrained( safe_serialization (`bool`, optional, defaults=True): If `True`, save the model using `safetensors`. If `False`, save the model using the traditional PyTorch way (using `pickle`). - variant (`str`, optional): + variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. """ idx = 0 From d41f1647237a19d4fc58e749184a42cf6689ba8e Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:22:56 +0900 Subject: [PATCH 14/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index a9fd34ce719c..91c4eb5dadc6 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -177,7 +177,7 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] max_memory (`Dict`, *optional*): A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory available for each GPU and the available CPU RAM if unset. - low_cpu_mem_usage (`bool`, optional, defaults=True if torch version >= 1.9.0 else `False`): + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, From b8f7e3a61cf3e47bc148125b08a6f3ad380f060f Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:23:33 +0900 Subject: [PATCH 15/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 91c4eb5dadc6..8141fda101a2 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -182,7 +182,7 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, setting this argument to `True` will raise an error. - variant (`str`, optional): + variant (`str`, *optional*): If specified, load weights from a `variant` file (*e.g.* pytorch_model..bin). `variant` will be ignored when using `from_flax`. 
use_safetensors (`bool`, optional, defaults=None): From 38624a26f337459fd3f90a725484a095119d02ea Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:23:40 +0900 Subject: [PATCH 16/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 8141fda101a2..982b4c7c2843 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -226,7 +226,7 @@ class T2IAdapter(ModelMixin, ConfigMixin): This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as downloading or saving. Args: - in_channels (`int`, optional, defaults=3): + in_channels (`int`, *optional*, defaults to `3`): Number of channels of Aapter's input(*control image*). Set to 1 if you're using gray scale image. channels (`List[int]`, optional, defaults=[320, 640, 1280, 1280]): Number of channel of each downsample block's output hidden state. The `len(block_out_channels)` determines From 14ddde4983ae95d8115c6696e9e28e120872dd47 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:23:49 +0900 Subject: [PATCH 17/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 982b4c7c2843..2f617d8f7302 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -228,7 +228,7 @@ class T2IAdapter(ModelMixin, ConfigMixin): Args: in_channels (`int`, *optional*, defaults to `3`): Number of channels of Aapter's input(*control image*). Set to 1 if you're using gray scale image. - channels (`List[int]`, optional, defaults=[320, 640, 1280, 1280]): + channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): Number of channel of each downsample block's output hidden state. The `len(block_out_channels)` determines the number of downsample blocks in the Adapter. num_res_blocks (`int`, *optional*, defaults to `2`): From 18fa091720db90f919ac223599d6600cda43d6bb Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Mon, 7 Oct 2024 13:23:56 +0900 Subject: [PATCH 18/29] Update src/diffusers/models/adapter.py --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 2f617d8f7302..479336f4c7ba 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -185,7 +185,7 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] variant (`str`, *optional*): If specified, load weights from a `variant` file (*e.g.* pytorch_model..bin). `variant` will be ignored when using `from_flax`. - use_safetensors (`bool`, optional, defaults=None): + use_safetensors (`bool`, *optional*, defaults to `None`): If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is installed. If `True`, the model will be forcibly loaded from`safetensors` weights. If `False`, `safetensors` is not used. 
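Taken together, the `from_pretrained` options above can be exercised with a short sketch; the directory is a
placeholder produced by an earlier `save_pretrained` call:

```py
import torch
from diffusers import MultiAdapter

multi_adapter = MultiAdapter.from_pretrained(
    "./my_model_directory/adapter",  # placeholder path
    torch_dtype=torch.float16,       # override the stored dtype
    use_safetensors=True,            # force loading from safetensors weights
)
multi_adapter.train()  # the model loads in eval mode; switch back for fine-tuning
```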
From 5c2628bd38b819fbf31ad5013d7be438c054fe70 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Sat, 12 Oct 2024 15:01:00 +0900 Subject: [PATCH 19/29] Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/diffusers/models/adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 479336f4c7ba..8bdd90cbe6a4 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -229,8 +229,8 @@ class T2IAdapter(ModelMixin, ConfigMixin): in_channels (`int`, *optional*, defaults to `3`): Number of channels of Aapter's input(*control image*). Set to 1 if you're using gray scale image. channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - Number of channel of each downsample block's output hidden state. The `len(block_out_channels)` determines - the number of downsample blocks in the Adapter. + The number of channels in each downsample block's output hidden state. The `len(block_out_channels)` determines + the number of downsample blocks in the adapter. num_res_blocks (`int`, *optional*, defaults to `2`): Number of ResNet blocks in each downsample block. downscale_factor (`int`, *optional*, defaults to `8`): From 45c08ebe0a73f9631a9abce035b613aceec3f613 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Sat, 12 Oct 2024 15:01:42 +0900 Subject: [PATCH 20/29] Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 8bdd90cbe6a4..14413d6d9981 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -236,7 +236,7 @@ class T2IAdapter(ModelMixin, ConfigMixin): downscale_factor (`int`, *optional*, defaults to `8`): A factor that determines the total downscale factor of the Adapter. adapter_type (`str`, *optional*, defaults to `full_adapter`): - Type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`. + Adapter type (`full_adapter` or `full_adapter_xl` or `light_adapter`) to use. """ @register_to_config From 2db9504129d97d4d12c9ea4b0e5352fac0e6fdab Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Sat, 12 Oct 2024 15:01:51 +0900 Subject: [PATCH 21/29] Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 14413d6d9981..a651b00f3aa5 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -127,7 +127,7 @@ def save_pretrained( need to replace `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment variable. safe_serialization (`bool`, optional, defaults=True): If `True`, save the model using `safetensors`. - If `False`, save the model using the traditional PyTorch way (using `pickle`). + If `False`, save the model with `pickle`. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. 
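        Example (a minimal sketch; the path is a placeholder):

        ```py
        from diffusers import MultiAdapter, T2IAdapter

        multi_adapter = MultiAdapter([T2IAdapter(), T2IAdapter()])

        # The first adapter is written to ./my_model_directory/adapter and the
        # second to ./my_model_directory/adapter_1.
        multi_adapter.save_pretrained(
            "./my_model_directory/adapter",
            safe_serialization=True,  # safetensors serialization (the default)
            variant="fp16",           # weight filenames gain a `.fp16.` infix
        )
        ```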
""" From 2687b96b2fa50052f9c3ff6cf7bc2e57683665c8 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Sat, 12 Oct 2024 15:02:11 +0900 Subject: [PATCH 22/29] Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/diffusers/models/adapter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index a651b00f3aa5..40f56f14d578 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -123,8 +123,7 @@ def save_pretrained( Useful when in distributed training (e.g., TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only for the main process to avoid race conditions. save_function (`Callable`): - Function used to save the state dictionary. Useful on distributed training (e.g., TPUs) when one - need to replace `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment variable. + Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment variable. safe_serialization (`bool`, optional, defaults=True): If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`. From b88e450282da0d7f3a68fc91ef4392f199bb88c7 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Sat, 12 Oct 2024 15:02:22 +0900 Subject: [PATCH 23/29] Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 40f56f14d578..1c2cb6a12e3b 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -120,7 +120,7 @@ def save_pretrained( The directory where the model will be saved. If the directory does not exist, it will be created. is_main_process (`bool`, optional, defaults=True): Indicates whether current process is the main process or not. - Useful when in distributed training (e.g., TPUs) and need to call this function on all processes. + Useful for distributed training (e.g., TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only for the main process to avoid race conditions. save_function (`Callable`): Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment variable. From d988d7d3f16f6c63103206d6d0c3da7ce99a4f3f Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Sat, 12 Oct 2024 15:02:29 +0900 Subject: [PATCH 24/29] Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 1c2cb6a12e3b..1808b45d1ece 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -82,7 +82,7 @@ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = Non The `channel` dimension should be equal to `num_adapter` * number of channel per image. 
adapter_weights (`List[float]`, *optional*, defaults to None): - A list of floats representing the weight which will be multiplied to each adapter's output before summing + A list of floats representing the weights which will be multiplied by each adapter's output before summing them together. If `None`, equal weights will be used for all adapters. """ From c8bbca0bff8d6bcef626f8233ef151116c296736 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Sat, 12 Oct 2024 15:02:36 +0900 Subject: [PATCH 25/29] Update src/diffusers/models/adapter.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/diffusers/models/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 1808b45d1ece..543713ee174e 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -226,7 +226,7 @@ class T2IAdapter(ModelMixin, ConfigMixin): Args: in_channels (`int`, *optional*, defaults to `3`): - Number of channels of Aapter's input(*control image*). Set to 1 if you're using gray scale image. + The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale image. channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The number of channels in each downsample block's output hidden state. The `len(block_out_channels)` determines the number of downsample blocks in the adapter. From fd3136074d4c70263873abe2fb4f1257a76a3c72 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Tue, 15 Oct 2024 13:41:53 +0000 Subject: [PATCH 26/29] run make style --- .../geodiff_molecule_conformation.ipynb | 7230 +++++++++-------- examples/research_projects/gligen/demo.ipynb | 13 +- 2 files changed, 3626 insertions(+), 3617 deletions(-) diff --git a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb index bde093802a5d..03f58f1f2f63 100644 --- a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb +++ b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb @@ -1,3652 +1,3660 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "F88mignPnalS" - }, - "source": [ - "# Introduction\n", - "\n", - "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", - "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", - "\n", - "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). 
What we want to generate is a stable 3d structure of the molecule.\n", - "\n", - "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", - "\n", - "> Colab made by [natolambert](https://twitter.com/natolambert).\n", - "\n", - "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7cnwXMocnuzB" - }, - "source": [ - "## Installations\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Install Conda" - ], - "metadata": { - "id": "ff9SxWnaNId9" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1g_6zOabItDk" - }, - "source": [ - "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "K0ofXobG5Y-X", - "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "nvcc: NVIDIA (R) Cuda compiler driver\n", - "Copyright (c) 2005-2021 NVIDIA Corporation\n", - "Built on Sun_Feb_14_21:12:58_PST_2021\n", - "Cuda compilation tools, release 11.2, V11.2.152\n", - "Build cuda_11.2.r11.2/compiler.29618528_0\n" - ] - } - ], - "source": [ - "!nvcc --version" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VfthW90vI0nw" - }, - "source": [ - "Install Conda for some more complex dependencies for geometric networks." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2WNFzSnbiE0k", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install -q condacolab" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NUsbWYCUI7Km" - }, - "source": [ - "Setup Conda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FZelreINdmd0", - "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✨🍰✨ Everything looks OK!\n" - ] - } - ], - "source": [ - "import condacolab\n", - "condacolab.install()" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "F88mignPnalS" + }, + "source": [ + "# Introduction\n", + "\n", + "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", + "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", + "\n", + "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). 
What we want to generate is a stable 3d structure of the molecule.\n", + "\n", + "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", + "\n", + "> Colab made by [natolambert](https://twitter.com/natolambert).\n", + "\n", + "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7cnwXMocnuzB" + }, + "source": [ + "## Installations\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ff9SxWnaNId9" + }, + "source": [ + "### Install Conda" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1g_6zOabItDk" + }, + "source": [ + "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "K0ofXobG5Y-X", + "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "nvcc: NVIDIA (R) Cuda compiler driver\n", + "Copyright (c) 2005-2021 NVIDIA Corporation\n", + "Built on Sun_Feb_14_21:12:58_PST_2021\n", + "Cuda compilation tools, release 11.2, V11.2.152\n", + "Build cuda_11.2.r11.2/compiler.29618528_0\n" + ] + } + ], + "source": [ + "!nvcc --version" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VfthW90vI0nw" + }, + "source": [ + "Install Conda for some more complex dependencies for geometric networks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2WNFzSnbiE0k", + "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q condacolab" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NUsbWYCUI7Km" + }, + "source": [ + "Setup Conda" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FZelreINdmd0", + "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✨🍰✨ Everything looks OK!\n" + ] + } + ], + "source": [ + "import condacolab\n", + "\n", + "\n", + "condacolab.install()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JzDHaPU7I9Sn" + }, + "source": [ + "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JMxRjHhL7w8V", + "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", + "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - cudatoolkit=11.1\n", + " - pytorch\n", + " - torchaudio\n", + " - torchvision\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 960 KB\n", + "\n", + "The following packages will be UPDATED:\n", + "\n", + " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", + "Preparing transaction: / \b\bdone\n", + "Verifying transaction: \\ \b\bdone\n", + "Executing transaction: / \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", + "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QDS6FPZ0Tu5b" + }, + "source": [ + "Need to remove a pathspec for colab that specifies the incorrect cuda version." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dq1lxR10TtrR", + "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" + ] + } + ], + "source": [ + "!rm /usr/local/conda-meta/pinned" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z1L3DdZOJB30" + }, + "source": [ + "Install torch geometric (used in the model later)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "D5ukfCOWfjzK", + "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - pytorch-geometric=1.7.2\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " decorator-4.4.2 | py_0 11 KB conda-forge\n", + " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", + " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", + " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", + " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", + " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", + " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", + " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", + " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", + " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", + " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", + " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", + " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", + " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", + " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", + " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", + " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", + " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", + " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", + " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", + " 
------------------------------------------------------------\n", + " Total: 55.9 MB\n", + "\n", + "The following NEW packages will be INSTALLED:\n", + "\n", + " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", + " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", + " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", + " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", + " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", + " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", + " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", + " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", + " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", + " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", + " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", + " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", + " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", + " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", + " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", + " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", + " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", + " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", + " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", + "\n", + "The following packages will be DOWNGRADED:\n", + "\n", + " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", + "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", + "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", + "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", + "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", + "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", + "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", + "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", + "pytorch-cluster-1.5. 
| 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", + "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", + "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", + "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", + "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", + "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", + "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", + "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", + "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", + "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", + "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", + "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", + "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", + "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install -c rusty1s pytorch-geometric=1.7.2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppxv6Mdkalbc" + }, + "source": [ + "### Install Diffusers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mgQA_XN-XGY2", + "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/content\n", + "Cloning into 'diffusers'...\n", + "remote: Enumerating objects: 9298, done.\u001b[K\n", + "remote: Counting objects: 100% (40/40), done.\u001b[K\n", + "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", + "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", + "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", + "Resolving deltas: 100% (6168/6168), done.\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "%cd /content\n", + "\n", + "# install latest HF diffusers (will update to the release once added)\n", + "!git clone https://github.com/huggingface/diffusers.git\n", + "!pip install -q /content/diffusers\n", + "\n", + "# dependencies for diffusers\n", + "!pip install -q datasets transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LZO6AJKuJKO8" + }, + "source": [ + "Check that torch is installed correctly and utilizing the GPU in the colab" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 53 }, + "id": "gZt7BNi1e1PA", + "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "JzDHaPU7I9Sn" - }, - "source": [ - "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMxRjHhL7w8V", - "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", - "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - cudatoolkit=11.1\n", - " - pytorch\n", - " - torchaudio\n", - " - torchvision\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 960 KB\n", - "\n", - "The following packages will be UPDATED:\n", - "\n", - " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", - "Preparing transaction: / \b\bdone\n", - "Verifying transaction: \\ \b\bdone\n", - "Executing transaction: / \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", - "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" + "text/plain": [ + "'1.8.2'" ] - }, - { - "cell_type": "markdown", - "source": [ - "Need to remove a pathspec for colab that specifies the incorrect cuda version." 
- ], - "metadata": { - "id": "QDS6FPZ0Tu5b" + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "\n", + "\n", + "print(torch.cuda.is_available())\n", + "torch.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KLE7CqlfJNUO" + }, + "source": [ + "### Install Chemistry-specific Dependencies\n", + "\n", + "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0CPv_NvehRz3", + "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting rdkit\n", + " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", + "Installing collected packages: rdkit\n", + "Successfully installed rdkit-2022.3.5\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install rdkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "88GaDbDPxJ5I" + }, + "source": [ + "### Get viewer from nglview\n", + "\n", + "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", + "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", + "The rdmol in this object is a source of ground truth for the generated molecules.\n", + "\n", + "You will use one rendering function from nglviewer later!\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "jcl8GCS2mz6t", + "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting nglview\n", + " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", + "Collecting jupyterlab-widgets\n", + " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipywidgets>=7\n", + " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting widgetsnbextension~=4.0\n", + " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipython>=6.1.0\n", + " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipykernel>=4.5.1\n", + " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting traitlets>=4.3.1\n", + " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", + "Collecting pyzmq>=17\n", + " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting matplotlib-inline>=0.1\n", + " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", + "Collecting tornado>=6.1\n", + " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nest-asyncio\n", + " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", + "Collecting debugpy>=1.0\n", + " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting psutil\n", + " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting 
jupyter-client>=6.1.12\n", + " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pickleshare\n", + " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", + "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", + "Collecting backcall\n", + " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", + "Collecting pexpect>4.3\n", + " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pygments\n", + " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jedi>=0.16\n", + " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", + " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", + "Collecting parso<0.9.0,>=0.8.0\n", + " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", + "Collecting entrypoints\n", + " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", + "Collecting jupyter-core>=4.9.2\n", + " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ptyprocess>=0.5\n", + " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", + "Collecting wcwidth\n", + " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", + "Building wheels for collected packages: nglview\n", + " Building wheel for nglview (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", + " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", + "Successfully built nglview\n", + "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", + "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + }, + { + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "pexpect", + "pickleshare", + "wcwidth" + ] + } } - }, - { - "cell_type": "code", - "source": [ - "!rm /usr/local/conda-meta/pinned" - ], - "metadata": { - "id": "dq1lxR10TtrR", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z1L3DdZOJB30" - }, - "source": [ - "Install torch geometric (used in the model later)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D5ukfCOWfjzK", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - 
"## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - pytorch-geometric=1.7.2\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " decorator-4.4.2 | py_0 11 KB conda-forge\n", - " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", - " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", - " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", - " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", - " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", - " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", - " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", - " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", - " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", - " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", - " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", - " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", - " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", - " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", - " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", - " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", - " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", - " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", - " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 55.9 MB\n", - "\n", - "The following NEW packages will be INSTALLED:\n", - "\n", - " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", - " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", - " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", - " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", - " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", - " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", - " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", - " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", - " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", - " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", - " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", - " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", - " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", - " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", - " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", - " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", - " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", - " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", - " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", - "\n", - "The following packages will be DOWNGRADED:\n", - "\n", - " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", - "pytorch-scatter-2.0. 
| 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", - "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", - "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", - "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", - "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", - "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", - "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", - "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", - "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", - "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", - "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", - "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", - "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", - "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", - "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", - "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", - "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", - "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", - "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", - "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", - "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install -c rusty1s pytorch-geometric=1.7.2" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ppxv6Mdkalbc" - }, - "source": [ - "### Install Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mgQA_XN-XGY2", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "!pip install nglview" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8t8_e_uVLdKB" + }, + "source": [ + "## Create a diffusion model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G0rMncVtNSqU" + }, + "source": [ + "### Model class(es)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L5FEXz5oXkzt" + }, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-3-P4w5sXkRU" + }, + "outputs": [], + "source": [ + "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", + "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", + "from dataclasses import dataclass\n", + "from typing import Callable, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor, nn\n", + "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", + "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", + "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", + "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", + "from torch_scatter import scatter_add\n", + "from torch_sparse import 
SparseTensor, coalesce\n", + "\n", + "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", + "from diffusers.modeling_utils import ModelMixin\n", + "from diffusers.utils import BaseOutput\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EzJQXPN_XrMX" + }, + "source": [ + "Helper classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oR1Y56QiLY90" + }, + "outputs": [], + "source": [ + "@dataclass\n", + "class MoleculeGNNOutput(BaseOutput):\n", + " \"\"\"\n", + " Args:\n", + " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", + " Hidden states output. Output of last layer of model.\n", + " \"\"\"\n", + "\n", + " sample: torch.Tensor\n", + "\n", + "\n", + "class MultiLayerPerceptron(nn.Module):\n", + " \"\"\"\n", + " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", + " Args:\n", + " input_dim (int): input dimension\n", + " hidden_dim (list of int): hidden dimensions\n", + " activation (str or function, optional): activation function\n", + " dropout (float, optional): dropout rate\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", + " super(MultiLayerPerceptron, self).__init__()\n", + "\n", + " self.dims = [input_dim] + hidden_dims\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", + " self.activation = None\n", + " if dropout > 0:\n", + " self.dropout = nn.Dropout(dropout)\n", + " else:\n", + " self.dropout = None\n", + "\n", + " self.layers = nn.ModuleList()\n", + " for i in range(len(self.dims) - 1):\n", + " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\"\"\"\n", + " for i, layer in enumerate(self.layers):\n", + " x = layer(x)\n", + " if i < len(self.layers) - 1:\n", + " if self.activation:\n", + " x = self.activation(x)\n", + " if self.dropout:\n", + " x = self.dropout(x)\n", + " return x\n", + "\n", + "\n", + "class ShiftedSoftplus(torch.nn.Module):\n", + " def __init__(self):\n", + " super(ShiftedSoftplus, self).__init__()\n", + " self.shift = torch.log(torch.tensor(2.0)).item()\n", + "\n", + " def forward(self, x):\n", + " return F.softplus(x) - self.shift\n", + "\n", + "\n", + "class CFConv(MessagePassing):\n", + " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", + " super(CFConv, self).__init__(aggr=\"add\")\n", + " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", + " self.lin2 = Linear(num_filters, out_channels)\n", + " self.nn = mlp\n", + " self.cutoff = cutoff\n", + " self.smooth = smooth\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", + " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", + " self.lin2.bias.data.fill_(0)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " if self.smooth:\n", + " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", + " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", + " else:\n", + " C = (edge_length <= self.cutoff).float()\n", + " W = self.nn(edge_attr) * C.view(-1, 1)\n", + "\n", + " x = self.lin1(x)\n", + " x = self.propagate(edge_index, x=x, W=W)\n", + " x = self.lin2(x)\n", + " return x\n", + "\n", + " 
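# SchNet-style continuous-filter message: scale each neighbor's features x_j by its filter weights W.\n",
+    "    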
def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", + " return x_j * W\n", + "\n", + "\n", + "class InteractionBlock(torch.nn.Module):\n", + " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", + " super(InteractionBlock, self).__init__()\n", + " mlp = Sequential(\n", + " Linear(num_gaussians, num_filters),\n", + " ShiftedSoftplus(),\n", + " Linear(num_filters, num_filters),\n", + " )\n", + " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", + " self.act = ShiftedSoftplus()\n", + " self.lin = Linear(hidden_channels, hidden_channels)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " x = self.conv(x, edge_index, edge_length, edge_attr)\n", + " x = self.act(x)\n", + " x = self.lin(x)\n", + " return x\n", + "\n", + "\n", + "class SchNetEncoder(Module):\n", + " def __init__(\n", + " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.hidden_channels = hidden_channels\n", + " self.num_filters = num_filters\n", + " self.num_interactions = num_interactions\n", + " self.cutoff = cutoff\n", + "\n", + " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", + "\n", + " self.interactions = ModuleList()\n", + " for _ in range(num_interactions):\n", + " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", + " self.interactions.append(block)\n", + "\n", + " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", + " if embed_node:\n", + " assert z.dim() == 1 and z.dtype == torch.long\n", + " h = self.embedding(z)\n", + " else:\n", + " h = z\n", + " for interaction in self.interactions:\n", + " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", + "\n", + " return h\n", + "\n", + "\n", + "class GINEConv(MessagePassing):\n", + " \"\"\"\n", + " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", + " https://arxiv.org/abs/1810.00826 paper. 
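The update computed here is nn((1 + eps) * x_i + sum_j act(x_j + e_ij)), matching the forward pass below.\n",
+    "    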
Note that this implementation has the added option of a custom activation.\n", + " \"\"\"\n", + "\n", + " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", + " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", + " self.nn = mlp\n", + " self.initial_eps = eps\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " if train_eps:\n", + " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", + " else:\n", + " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", + "\n", + " def forward(\n", + " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", + " ) -> torch.Tensor:\n", + " \"\"\"\"\"\"\n", + " if isinstance(x, torch.Tensor):\n", + " x: OptPairTensor = (x, x)\n", + "\n", + " # Node and edge feature dimensionalites need to match.\n", + " if isinstance(edge_index, torch.Tensor):\n", + " assert edge_attr is not None\n", + " assert x[0].size(-1) == edge_attr.size(-1)\n", + " elif isinstance(edge_index, SparseTensor):\n", + " assert x[0].size(-1) == edge_index.size(-1)\n", + "\n", + " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", + " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", + "\n", + " x_r = x[1]\n", + " if x_r is not None:\n", + " out += (1 + self.eps) * x_r\n", + "\n", + " return self.nn(out)\n", + "\n", + " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", + " if self.activation:\n", + " return self.activation(x_j + edge_attr)\n", + " else:\n", + " return x_j + edge_attr\n", + "\n", + " def __repr__(self):\n", + " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", + "\n", + "\n", + "class GINEncoder(torch.nn.Module):\n", + " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", + " super().__init__()\n", + "\n", + " self.hidden_dim = hidden_dim\n", + " self.num_convs = num_convs\n", + " self.short_cut = short_cut\n", + " self.concat_hidden = concat_hidden\n", + " self.node_emb = nn.Embedding(100, hidden_dim)\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " self.convs = nn.ModuleList()\n", + " for i in range(self.num_convs):\n", + " self.convs.append(\n", + " GINEConv(\n", + " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", + " activation=activation,\n", + " )\n", + " )\n", + "\n", + " def forward(self, z, edge_index, edge_attr):\n", + " \"\"\"\n", + " Input:\n", + " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", + " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", + " Output:\n", + " node_feature: graph feature\n", + " \"\"\"\n", + "\n", + " node_attr = self.node_emb(z) # (num_node, hidden)\n", + "\n", + " hiddens = []\n", + " conv_input = node_attr # (num_node, hidden)\n", + "\n", + " for conv_idx, conv in enumerate(self.convs):\n", + " hidden = conv(conv_input, edge_index, edge_attr)\n", + " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", + " hidden = self.activation(hidden)\n", + " assert hidden.shape == conv_input.shape\n", + " if self.short_cut and hidden.shape == conv_input.shape:\n", + " hidden += conv_input\n", + "\n", + " hiddens.append(hidden)\n", + " 
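# this layer's output becomes the next layer's input\n",
+    "            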
conv_input = hidden\n", + "\n", + " if self.concat_hidden:\n", + " node_feature = torch.cat(hiddens, dim=-1)\n", + " else:\n", + " node_feature = hiddens[-1]\n", + "\n", + " return node_feature\n", + "\n", + "\n", + "class MLPEdgeEncoder(Module):\n", + " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", + " super().__init__()\n", + " self.hidden_dim = hidden_dim\n", + " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", + " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", + "\n", + " @property\n", + " def out_channels(self):\n", + " return self.hidden_dim\n", + "\n", + " def forward(self, edge_length, edge_type):\n", + " \"\"\"\n", + " Input:\n", + " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", + " Returns:\n", + " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", + " \"\"\"\n", + " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", + " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", + " return d_emb * edge_attr # (num_edge, hidden)\n", + "\n", + "\n", + "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", + " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", + " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", + " return h_pair\n", + "\n", + "\n", + "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", + " \"\"\"\n", + " Args:\n", + " num_nodes: Number of atoms.\n", + " edge_index: Bond indices of the original graph.\n", + " edge_type: Bond types of the original graph.\n", + " order: Extension order.\n", + " Returns:\n", + " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n", + " \"\"\"\n", + "\n", + " def binarize(x):\n", + " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", + "\n", + " def get_higher_order_adj_matrix(adj, order):\n", + " \"\"\"\n", + " Args:\n", + " adj: (N, N)\n", + " type_mat: (N, N)\n", + " Returns:\n", + " Following attributes will be updated:\n", + " - edge_index\n", + " - edge_type\n", + " Following attributes will be added to the data object:\n", + " - bond_edge_index: Original edge_index.\n", + " \"\"\"\n", + " adj_mats = [\n", + " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", + " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", + " ]\n", + "\n", + " for i in range(2, order + 1):\n", + " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", + " order_mat = torch.zeros_like(adj)\n", + "\n", + " for i in range(1, order + 1):\n", + " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", + "\n", + " return order_mat\n", + "\n", + " num_types = 22\n", + " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", + " # from rdkit.Chem.rdchem import BondType as BT\n", + " N = num_nodes\n", + " adj = to_dense_adj(edge_index).squeeze(0)\n", + " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", + "\n", + " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", + " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", + " assert (type_mat * type_highorder == 0).all()\n", + " type_new = type_mat + type_highorder\n", + "\n", + " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", + " _, edge_order = dense_to_sparse(adj_order)\n", + "\n", + " # data.bond_edge_index = data.edge_index # Save 
original edges\n", + " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", + " assert edge_type.dim() == 1\n", + " N = pos.size(0)\n", + "\n", + " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", + "\n", + " if is_sidechain is None:\n", + " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", + " else:\n", + " # fetch sidechain and its batch index\n", + " is_sidechain = is_sidechain.bool()\n", + " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", + " sidechain_pos = pos[is_sidechain]\n", + " sidechain_index = dummy_index[is_sidechain]\n", + " sidechain_batch = batch[is_sidechain]\n", + "\n", + " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", + " r_edge_index_x = assign_index[1]\n", + " r_edge_index_y = assign_index[0]\n", + " r_edge_index_y = sidechain_index[r_edge_index_y]\n", + "\n", + " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", + " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", + " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", + " # delete self loop\n", + " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", + "\n", + " rgraph_adj = torch.sparse.LongTensor(\n", + " rgraph_edge_index,\n", + " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", + " torch.Size([N, N]),\n", + " )\n", + "\n", + " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", + "\n", + " new_edge_index = composed_adj.indices()\n", + " new_edge_type = composed_adj.values().long()\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def extend_graph_order_radius(\n", + " num_nodes,\n", + " pos,\n", + " edge_index,\n", + " edge_type,\n", + " batch,\n", + " order=3,\n", + " cutoff=10.0,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + "):\n", + " if extend_order:\n", + " edge_index, edge_type = _extend_graph_order(\n", + " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", + " )\n", + "\n", + " if extend_radius:\n", + " edge_index, edge_type = _extend_to_radius_graph(\n", + " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", + " )\n", + "\n", + " return edge_index, edge_type\n", + "\n", + "\n", + "def get_distance(pos, edge_index):\n", + " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", + "\n", + "\n", + "def graph_field_network(score_d, pos, edge_index, edge_length):\n", + " \"\"\"\n", + " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. 
See equations\n", + " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", + " \"\"\"\n", + " N = pos.size(0)\n", + " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", + " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", + " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", + " ) # (N, 3)\n", + " return score_pos\n", + "\n", + "\n", + "def clip_norm(vec, limit, p=2):\n", + " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", + " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", + " return vec * denom\n", + "\n", + "\n", + "def is_local_edge(edge_type):\n", + " return edge_type > 0\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QWrHJFcYXyUB" + }, + "source": [ + "Main model class!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MCeZA1qQXzoK" + }, + "outputs": [], + "source": [ + "class MoleculeGNN(ModelMixin, ConfigMixin):\n", + " @register_to_config\n", + " def __init__(\n", + " self,\n", + " hidden_dim=128,\n", + " num_convs=6,\n", + " num_convs_local=4,\n", + " cutoff=10.0,\n", + " mlp_act=\"relu\",\n", + " edge_order=3,\n", + " edge_encoder=\"mlp\",\n", + " smooth_conv=True,\n", + " ):\n", + " super().__init__()\n", + " self.cutoff = cutoff\n", + " self.edge_encoder = edge_encoder\n", + " self.edge_order = edge_order\n", + "\n", + " \"\"\"\n", + " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", + " in SchNetEncoder\n", + " \"\"\"\n", + " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + "\n", + " \"\"\"\n", + " The graph neural network that extracts node-wise features.\n", + " \"\"\"\n", + " self.encoder_global = SchNetEncoder(\n", + " hidden_channels=hidden_dim,\n", + " num_filters=hidden_dim,\n", + " num_interactions=num_convs,\n", + " edge_channels=self.edge_encoder_global.out_channels,\n", + " cutoff=cutoff,\n", + " smooth=smooth_conv,\n", + " )\n", + " self.encoder_local = GINEncoder(\n", + " hidden_dim=hidden_dim,\n", + " num_convs=num_convs_local,\n", + " )\n", + "\n", + " \"\"\"\n", + " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", + " gradients w.r.t. 
edge_length (out_dim = 1).\n", + " \"\"\"\n", + " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " \"\"\"\n", + " Incorporate parameters together\n", + " \"\"\"\n", + " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", + " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", + "\n", + " def _forward(\n", + " self,\n", + " atom_type,\n", + " pos,\n", + " bond_index,\n", + " bond_type,\n", + " batch,\n", + " time_step, # NOTE, model trained without timestep performed best\n", + " edge_index=None,\n", + " edge_type=None,\n", + " edge_length=None,\n", + " return_edges=False,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " atom_type: Types of atoms, (N, ).\n", + " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", + " bond_type: Bond types, (E, ).\n", + " batch: Node index to graph index, (N, ).\n", + " \"\"\"\n", + " N = atom_type.size(0)\n", + " if edge_index is None or edge_type is None or edge_length is None:\n", + " edge_index, edge_type = extend_graph_order_radius(\n", + " num_nodes=N,\n", + " pos=pos,\n", + " edge_index=bond_index,\n", + " edge_type=bond_type,\n", + " batch=batch,\n", + " order=self.edge_order,\n", + " cutoff=self.cutoff,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " is_sidechain=is_sidechain,\n", + " )\n", + " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", + " local_edge_mask = is_local_edge(edge_type) # (E, )\n", + "\n", + " # with the parameterization of NCSNv2\n", + " # DDPM loss implicit handle the noise variance scale conditioning\n", + " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", + "\n", + " # Encoding global\n", + " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + "\n", + " # Global\n", + " node_attr_global = self.encoder_global(\n", + " z=atom_type,\n", + " edge_index=edge_index,\n", + " edge_length=edge_length,\n", + " edge_attr=edge_attr_global,\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_global = assemble_atom_pair_feature(\n", + " node_attr=node_attr_global,\n", + " edge_index=edge_index,\n", + " edge_attr=edge_attr_global,\n", + " ) # (E_global, 2H)\n", + " # Invariant features of edges (radius graph, global)\n", + " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", + "\n", + " # Encoding local\n", + " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + " # edge_attr += temb_edge\n", + "\n", + " # Local\n", + " node_attr_local = self.encoder_local(\n", + " z=atom_type,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_local = assemble_atom_pair_feature(\n", + " node_attr=node_attr_local,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " ) # (E_local, 2H)\n", + "\n", + " # Invariant features of edges (bond graph, local)\n", + " if 
isinstance(sigma_edge, torch.Tensor):\n",
+    "            edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n",
+    "                1.0 / sigma_edge[local_edge_mask]\n",
+    "            )  # (E_local, 1)\n",
+    "        else:\n",
+    "            edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge)  # (E_local, 1)\n",
+    "\n",
+    "        if return_edges:\n",
+    "            return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n",
+    "        else:\n",
+    "            return edge_inv_global, edge_inv_local\n",
+    "\n",
+    "    def forward(\n",
+    "        self,\n",
+    "        sample,\n",
+    "        timestep: Union[torch.Tensor, float, int],\n",
+    "        return_dict: bool = True,\n",
+    "        sigma=1.0,\n",
+    "        global_start_sigma=0.5,\n",
+    "        w_global=1.0,\n",
+    "        extend_order=False,\n",
+    "        extend_radius=True,\n",
+    "        clip_local=None,\n",
+    "        clip_global=1000.0,\n",
+    "    ) -> Union[MoleculeGNNOutput, Tuple]:\n",
+    "        r\"\"\"\n",
+    "        Args:\n",
+    "            sample: packed torch geometric object\n",
+    "            timestep (`torch.Tensor` or `float` or `int`): TODO verify type and shape (batch) timesteps\n",
+    "            return_dict (`bool`, *optional*, defaults to `True`):\n",
+    "                Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n",
+    "        Returns:\n",
+    "            [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n",
+    "            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n",
+    "        \"\"\"\n",
+    "\n",
+    "        # unpack sample\n",
+    "        atom_type = sample.atom_type\n",
+    "        bond_index = sample.edge_index\n",
+    "        bond_type = sample.edge_type\n",
+    "        num_graphs = sample.num_graphs\n",
+    "        pos = sample.pos\n",
+    "\n",
+    "        timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n",
+    "\n",
+    "        edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n",
+    "            atom_type=atom_type,\n",
+    "            pos=sample.pos,\n",
+    "            bond_index=bond_index,\n",
+    "            bond_type=bond_type,\n",
+    "            batch=sample.batch,\n",
+    "            time_step=timesteps,\n",
+    "            return_edges=True,\n",
+    "            extend_order=extend_order,\n",
+    "            extend_radius=extend_radius,\n",
+    "        )  # (E_global, 1), (E_local, 1)\n",
+    "\n",
+    "        # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n",
+    "        node_eq_local = graph_field_network(\n",
+    "            edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n",
+    "        )\n",
+    "        if clip_local is not None:\n",
+    "            node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n",
+    "\n",
+    "        # Global\n",
+    "        if sigma < global_start_sigma:\n",
+    "            edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n",
+    "            node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n",
+    "            node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n",
+    "        else:\n",
+    "            node_eq_global = 0\n",
+    "\n",
+    "        # Sum\n",
+    "        eps_pos = node_eq_local + node_eq_global * w_global\n",
+    "\n",
+    "        if not return_dict:\n",
+    "            return (-eps_pos,)\n",
+    "\n",
+    "        return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "CCIrPYSJj9wd"
+   },
+   "source": [
+    "### Load pretrained model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "YdrAr6Ch--Ab"
+   },
+   "source": [
+    "#### Load a model\n",
+    "The model used here is built around an\n",
+    "equivariant convolutional layer, named graph field network (GFN).\n",
+    "\n",
+    "The warning about `betas` and `alphas` can be 
ignored, those were moved to the scheduler." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 172, + "referenced_widgets": [ + "d90f304e9560472eacfbdd11e46765eb", + "1c6246f15b654f4daa11c9bcf997b78c", + "c2321b3bff6f490ca12040a20308f555", + "b7feb522161f4cf4b7cc7c1a078ff12d", + "e2d368556e494ae7ae4e2e992af2cd4f", + "bbef741e76ec41b7ab7187b487a383df", + "561f742d418d4721b0670cc8dd62e22c", + "872915dd1bb84f538c44e26badabafdd", + "d022575f1fa2446d891650897f187b4d", + "fdc393f3468c432aa0ada05e238a5436", + "2c9362906e4b40189f16d14aa9a348da", + "6010fc8daa7a44d5aec4b830ec2ebaa1", + "7e0bb1b8d65249d3974200686b193be2", + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "6526646be5ed415c84d1245b040e629b", + "24d31fc3576e43dd9f8301d2ef3a37ab", + "2918bfaadc8d4b1a9832522c40dfefb8", + "a4bfdca35cc54dae8812720f1b276a08", + "e4901541199b45c6a18824627692fc39", + "f915cf874246446595206221e900b2fe", + "a9e388f22a9742aaaf538e22575c9433", + "42f6c3db29d7484ba6b4f73590abd2f4" + ] + }, + "id": "DyCo0nsqjbml", + "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d90f304e9560472eacfbdd11e46765eb", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "/content\n", - "Cloning into 'diffusers'...\n", - "remote: Enumerating objects: 9298, done.\u001b[K\n", - "remote: Counting objects: 100% (40/40), done.\u001b[K\n", - "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", - "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", - "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", - "Resolving deltas: 100% (6168/6168), done.\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "%cd /content\n", - "\n", - "# install latest HF diffusers (will update to the release once added)\n", - "!git clone https://github.com/huggingface/diffusers.git\n", - "!pip install -q /content/diffusers\n", - "\n", - "# dependencies for diffusers\n", - "!pip install -q datasets transformers" + "text/plain": [ + "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", + "\n", + "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", + "\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "\n", + "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", + "dataset = torch.load('/content/molecules.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QZcmy1EvKQRk" + }, + "source": [ + "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "JVjz6iH_H6Eh",
+ "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vHNiZAUxNgoy"
+ },
+ "source": [
+ "## Run the diffusion process"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jZ1KZrxKqENg"
+ },
+ "source": [
+ "#### Helper Functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "s240tYueqKKf"
+ },
+ "outputs": [],
+ "source": [
+ "import copy\n",
+ "import os\n",
+ "\n",
+ "from torch_geometric.data import Batch, Data\n",
+ "from torch_scatter import scatter_mean\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "\n",
+ "def repeat_data(data: Data, num_repeat) -> Batch:\n",
+ "    datas = [copy.deepcopy(data) for i in range(num_repeat)]\n",
+ "    return Batch.from_data_list(datas)\n",
+ "\n",
+ "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n",
+ "    datas = batch.to_data_list()\n",
+ "    new_data = []\n",
+ "    for i in range(num_repeat):\n",
+ "        new_data += copy.deepcopy(datas)\n",
+ "    return Batch.from_data_list(new_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AMnQTk0eqT7Z"
+ },
+ "source": [
+ "#### Constants"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WYGkzqgzrHmF"
+ },
+ "outputs": [],
+ "source": [
+ "num_samples = 1  # solutions per molecule\n",
+ "num_molecules = 3\n",
+ "\n",
+ "DEVICE = 'cuda'\n",
+ "sampling_type = 'ddpm_noisy'  # paper also uses \"generalize\" and \"ld\"\n",
+ "# constants for inference\n",
+ "w_global = 0.5  # 0.3 for qm9\n",
+ "global_start_sigma = 0.5\n",
+ "eta = 1.0\n",
+ "clip_local = None\n",
+ "clip_pos = None\n",
+ "\n",
+ "# constants for data handling\n",
+ "save_traj = False\n",
+ "save_data = False\n",
+ "output_dir = '/content/'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-xD5bJ3SqM7t"
+ },
+ "source": [
+ "#### Generate samples!\n",
+ "Note that the 3D representation of a molecule is referred to as its **conformation**."
+ ]
+ },
+ {
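A small usage sketch of `repeat_data` from the helper cell above (on a toy graph, not one of the dataset molecules): it deep-copies a single `Data` object `num_repeat` times and packs the copies into one `Batch`, whose `batch` vector maps every node back to its copy.

```python
# Sketch: batch three copies of a toy two-node graph with repeat_data.
import torch
from torch_geometric.data import Data

toy = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0, 1], [1, 0]]))
batch = repeat_data(toy, num_repeat=3)
print(batch.num_graphs)  # 3
print(batch.batch)       # tensor([0, 0, 1, 1, 2, 2]) -- node-to-graph assignment
```

The sampling loop below relies on exactly this: `num_samples` copies of one molecule are denoised in parallel as a single batch.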
"cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "x9xuLUNg26z1", + "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " after removing the cwd from sys.path.\n", + "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" + ] + } + ], + "source": [ + "results = []\n", + "\n", + "# define sigmas\n", + "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", + "sigmas = sigmas.to(DEVICE)\n", + "\n", + "for count, data in enumerate(tqdm(dataset)):\n", + " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", + "\n", + " data_input = data.clone()\n", + " data_input['pos_ref'] = None\n", + " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", + "\n", + " # initial configuration\n", + " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", + "\n", + " # for logging animation of denoising\n", + " pos_traj = []\n", + " with torch.no_grad():\n", + "\n", + " # scale initial sample\n", + " pos = pos_init * sigmas[-1]\n", + " for t in scheduler.timesteps:\n", + " batch.pos = pos\n", + "\n", + " # generate geometry with model, then filter it\n", + " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", + "\n", + " # Update\n", + " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", + "\n", + " pos = reconstructed_pos\n", + "\n", + " if torch.isnan(pos).any():\n", + " print(\"NaN detected. Please restart.\")\n", + " raise FloatingPointError()\n", + "\n", + " # recenter graph of positions for next iteration\n", + " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", + "\n", + " # optional clipping\n", + " if clip_pos is not None:\n", + " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", + " pos_traj.append(pos.clone().cpu())\n", + "\n", + " pos_gen = pos.cpu()\n", + " if save_traj:\n", + " pos_gen_traj = pos_traj.cpu()\n", + " data.pos_gen = torch.stack(pos_gen_traj)\n", + " else:\n", + " data.pos_gen = pos_gen\n", + " results.append(data)\n", + "\n", + "\n", + "if save_data:\n", + " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", + "\n", + " with open(save_path, 'wb') as f:\n", + " pickle.dump(results, f)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fSApwSaZNndW" + }, + "source": [ + "## Render the results!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d47Zxo2OKdgZ" + }, + "source": [ + "This function allows us to render 3d in colab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e9Cd0kCAv9b8" + }, + "outputs": [], + "source": [ + "from google.colab import output\n", + "\n", + "\n", + "output.enable_custom_widget_manager()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RjaVuR15NqzF" + }, + "source": [ + "### Helper functions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "28rBYa9NKhlz" + }, + "source": [ + "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LKdKdwxcyTQ6"
+ },
+ "outputs": [],
+ "source": [
+ "from copy import deepcopy\n",
+ "\n",
+ "\n",
+ "def set_rdmol_positions(rdkit_mol, pos):\n",
+ "    \"\"\"\n",
+ "    Return a copy of `rdkit_mol` with its conformer set to `pos`.\n",
+ "\n",
+ "    Args:\n",
+ "        rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
+ "        pos: (N_atoms, 3)\n",
+ "    \"\"\"\n",
+ "    mol = deepcopy(rdkit_mol)\n",
+ "    set_rdmol_positions_(mol, pos)\n",
+ "    return mol\n",
+ "\n",
+ "def set_rdmol_positions_(mol, pos):\n",
+ "    \"\"\"\n",
+ "    Set the conformer of `mol` to `pos` in place.\n",
+ "\n",
+ "    Args:\n",
+ "        mol: An `rdkit.Chem.rdchem.Mol` object.\n",
+ "        pos: (N_atoms, 3)\n",
+ "    \"\"\"\n",
+ "    for i in range(pos.shape[0]):\n",
+ "        mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n",
+ "    return mol\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NuE10hcpKmzK"
+ },
+ "source": [
+ "Process the generated data to make it easy to view."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "KieVE1vc0_Vs",
+ "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "collected 5 generated molecules in `mols_gen`\n"
+ ]
+ }
+ ],
+ "source": [
+ "# the model can generate multiple conformations per 2D geometry\n",
+ "num_gen = results[0]['pos_gen'].shape[0]\n",
+ "\n",
+ "# init storage objects\n",
+ "mols_gen = []\n",
+ "mols_orig = []\n",
+ "for to_process in results:\n",
+ "\n",
+ "    # store the reference 3D positions\n",
+ "    to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n",
+ "\n",
+ "    # store the generated 3D positions\n",
+ "    to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n",
+ "\n",
+ "    # copy data to new object\n",
+ "    new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n",
+ "\n",
+ "    # append results\n",
+ "    mols_gen.append(new_mol)\n",
+ "    mols_orig.append(to_process.rdmol)\n",
+ "\n",
+ "print(f\"collected {len(mols_gen)} generated molecules in `mols_gen`\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tin89JwMKp4v"
+ },
+ "source": [
+ "Import tools to visualize the 2D chemical diagram of the molecule."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yqV6gllSZn38"
+ },
+ "outputs": [],
+ "source": [
+ "from IPython.display import SVG, display\n",
+ "from rdkit import Chem\n",
+ "from rdkit.Chem.Draw import rdMolDraw2D as MD2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TFNKmGddVoOk"
+ },
+ "source": [
+ "Select a molecule to visualize."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KzuwLlrrVaGc"
+ },
+ "outputs": [],
+ "source": [
+ "idx = 0\n",
+ "assert idx < len(results), \"selected a molecule that was not generated\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hkb8w0_SNtU8"
+ },
+ "source": [
+ "### Viewing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I3R4QBQeKttN"
+ },
+ "source": [
+ "This 2D rendering is the equivalent of the **input to the model**!"
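A minimal usage sketch of `set_rdmol_positions` on a hypothetical molecule (not from the dataset; assumes `rdkit` is installed, as above). `EmbedMolecule` gives the mol a conformer 0, which is what the helper overwrites:

```python
# Sketch: overwrite an RDKit conformer with an arbitrary (N_atoms, 3) tensor.
import torch
from rdkit import Chem
from rdkit.Chem import AllChem

mol = Chem.AddHs(Chem.MolFromSmiles("CCO"))   # toy molecule
AllChem.EmbedMolecule(mol)                    # creates conformer 0
new_pos = torch.randn(mol.GetNumAtoms(), 3)   # stand-in for a generated conformation
mol_new = set_rdmol_positions(mol, new_pos)   # returns a deep copy; `mol` is unchanged
```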
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 321 + }, + "id": "gkQRWjraaKex", + "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" + }, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "text/plain": [ + "" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0CPv_NvehRz3", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", + "molSize=(450,300)\n", + "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", + "drawer.DrawMolecule(mc)\n", + "drawer.FinishDrawing()\n", + "svg = drawer.GetDrawingText()\n", + "display(SVG(svg.replace('svg:','')))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z4FDMYMxKw2I" + }, + "source": [ + "Generate the 3d molecule!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "695ab5bbf30a4ab19df1f9f33469f314", + "eac6a8dcdc9d4335a2e51031793ead29" + ] + }, + "id": "aT1Bkb8YxJfV", + "outputId": "b98870ae-049d-4386-b676-166e9526bda2" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "695ab5bbf30a4ab19df1f9f33469f314", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting rdkit\n", - " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", - "Installing collected packages: rdkit\n", - "Successfully installed rdkit-2022.3.5\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] + "text/plain": [] + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" } - ], - "source": [ - "!pip install rdkit" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "88GaDbDPxJ5I" + } + } + }, + "output_type": "display_data" + } + ], + "source": [ + "from nglview import show_rdkit as show" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337, + "referenced_widgets": [ + "be446195da2b4ff2aec21ec5ff963a54", + "c6596896148b4a8a9c57963b67c7782f", + "2489b5e5648541fbbdceadb05632a050", + "01e0ba4e5da04914b4652b8d58565d7b", + "c30e6c2f3e2a44dbbb3d63bd519acaa4", + "f31c6e40e9b2466a9064a2669933ecd5", + "19308ccac642498ab8b58462e3f1b0bb", + "4a081cdc2ec3421ca79dd933b7e2b0c4", + "e5c0d75eb5e1447abd560c8f2c6017e1", + "5146907ef6764654ad7d598baebc8b58", + "144ec959b7604a2cabb5ca46ae5e5379", + "abce2a80e6304df3899109c6d6cac199", + "65195cb7a4134f4887e9dd19f3676462" + ] + }, + "id": "pxtq8I-I18C-", + "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "be446195da2b4ff2aec21ec5ff963a54", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "### Get viewer from nglview\n", - "\n", - "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", - "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", - "The rdmol in this object is a source of ground truth for the generated molecules.\n", - "\n", - "You will use one rendering function from nglviewer later!\n", - "\n" + "text/plain": [ + "NGLWidget()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jcl8GCS2mz6t", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting nglview\n", - " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", - "Collecting jupyterlab-widgets\n", - " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipywidgets>=7\n", - " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting widgetsnbextension~=4.0\n", - " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipython>=6.1.0\n", - " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipykernel>=4.5.1\n", - " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting traitlets>=4.3.1\n", - " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", - "Collecting pyzmq>=17\n", - " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting matplotlib-inline>=0.1\n", - " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", - "Collecting tornado>=6.1\n", - " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting nest-asyncio\n", - " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", - "Collecting debugpy>=1.0\n", - " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting psutil\n", - " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting 
jupyter-client>=6.1.12\n", - " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pickleshare\n", - " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", - "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", - "Collecting backcall\n", - " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", - "Collecting pexpect>4.3\n", - " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pygments\n", - " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting jedi>=0.16\n", - " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", - " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", - "Collecting parso<0.9.0,>=0.8.0\n", - " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", - "Collecting entrypoints\n", - " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", - "Collecting jupyter-core>=4.9.2\n", - " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ptyprocess>=0.5\n", - " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", - "Collecting wcwidth\n", - " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", - "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", - "Building wheels for collected packages: nglview\n", - " Building wheel for nglview (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", - " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", - "Successfully built nglview\n", - "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", - "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - }, - { - "output_type": "display_data", - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "pexpect", - "pickleshare", - "wcwidth" - ] - } - } - }, - "metadata": {} + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" } - ], - "source": [ - "!pip install nglview" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Create a diffusion model" - ], - "metadata": { - "id": "8t8_e_uVLdKB" + } } - }, - { - "cell_type": "markdown", - "source": [ - "### Model class(es)" - ], - "metadata": { - "id": "G0rMncVtNSqU" - } - }, - { - "cell_type": "markdown", - "source": [ - "Imports" - ], - "metadata": { - "id": "L5FEXz5oXkzt" - } - }, - { - "cell_type": "code", - "source": [ - "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", - "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", - "from dataclasses import dataclass\n", - "from typing import Callable, Tuple, Union\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from torch import Tensor, nn\n", - "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", - "\n", - "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", - "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", - "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", - "from torch_scatter import scatter_add\n", - "from torch_sparse import SparseTensor, coalesce\n", - "\n", - "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", - "from diffusers.modeling_utils import ModelMixin\n", - "from diffusers.utils import BaseOutput\n" - ], - "metadata": { - "id": "-3-P4w5sXkRU" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Helper classes" - ], - "metadata": { - "id": "EzJQXPN_XrMX" - } - }, - { - "cell_type": "code", - "source": [ - 
"@dataclass\n", - "class MoleculeGNNOutput(BaseOutput):\n", - " \"\"\"\n", - " Args:\n", - " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", - " Hidden states output. Output of last layer of model.\n", - " \"\"\"\n", - "\n", - " sample: torch.Tensor\n", - "\n", - "\n", - "class MultiLayerPerceptron(nn.Module):\n", - " \"\"\"\n", - " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", - " Args:\n", - " input_dim (int): input dimension\n", - " hidden_dim (list of int): hidden dimensions\n", - " activation (str or function, optional): activation function\n", - " dropout (float, optional): dropout rate\n", - " \"\"\"\n", - "\n", - " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", - " super(MultiLayerPerceptron, self).__init__()\n", - "\n", - " self.dims = [input_dim] + hidden_dims\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", - " self.activation = None\n", - " if dropout > 0:\n", - " self.dropout = nn.Dropout(dropout)\n", - " else:\n", - " self.dropout = None\n", - "\n", - " self.layers = nn.ModuleList()\n", - " for i in range(len(self.dims) - 1):\n", - " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\"\"\"\n", - " for i, layer in enumerate(self.layers):\n", - " x = layer(x)\n", - " if i < len(self.layers) - 1:\n", - " if self.activation:\n", - " x = self.activation(x)\n", - " if self.dropout:\n", - " x = self.dropout(x)\n", - " return x\n", - "\n", - "\n", - "class ShiftedSoftplus(torch.nn.Module):\n", - " def __init__(self):\n", - " super(ShiftedSoftplus, self).__init__()\n", - " self.shift = torch.log(torch.tensor(2.0)).item()\n", - "\n", - " def forward(self, x):\n", - " return F.softplus(x) - self.shift\n", - "\n", - "\n", - "class CFConv(MessagePassing):\n", - " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", - " super(CFConv, self).__init__(aggr=\"add\")\n", - " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", - " self.lin2 = Linear(num_filters, out_channels)\n", - " self.nn = mlp\n", - " self.cutoff = cutoff\n", - " self.smooth = smooth\n", - "\n", - " self.reset_parameters()\n", - "\n", - " def reset_parameters(self):\n", - " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", - " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", - " self.lin2.bias.data.fill_(0)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " if self.smooth:\n", - " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", - " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", - " else:\n", - " C = (edge_length <= self.cutoff).float()\n", - " W = self.nn(edge_attr) * C.view(-1, 1)\n", - "\n", - " x = self.lin1(x)\n", - " x = self.propagate(edge_index, x=x, W=W)\n", - " x = self.lin2(x)\n", - " return x\n", - "\n", - " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", - " return x_j * W\n", - "\n", - "\n", - "class InteractionBlock(torch.nn.Module):\n", - " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", - " super(InteractionBlock, self).__init__()\n", - " mlp = Sequential(\n", - " Linear(num_gaussians, num_filters),\n", - " ShiftedSoftplus(),\n", - " Linear(num_filters, num_filters),\n", - " )\n", - " self.conv = 
CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", - " self.act = ShiftedSoftplus()\n", - " self.lin = Linear(hidden_channels, hidden_channels)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " x = self.conv(x, edge_index, edge_length, edge_attr)\n", - " x = self.act(x)\n", - " x = self.lin(x)\n", - " return x\n", - "\n", - "\n", - "class SchNetEncoder(Module):\n", - " def __init__(\n", - " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", - " ):\n", - " super().__init__()\n", - "\n", - " self.hidden_channels = hidden_channels\n", - " self.num_filters = num_filters\n", - " self.num_interactions = num_interactions\n", - " self.cutoff = cutoff\n", - "\n", - " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", - "\n", - " self.interactions = ModuleList()\n", - " for _ in range(num_interactions):\n", - " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", - " self.interactions.append(block)\n", - "\n", - " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", - " if embed_node:\n", - " assert z.dim() == 1 and z.dtype == torch.long\n", - " h = self.embedding(z)\n", - " else:\n", - " h = z\n", - " for interaction in self.interactions:\n", - " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", - "\n", - " return h\n", - "\n", - "\n", - "class GINEConv(MessagePassing):\n", - " \"\"\"\n", - " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", - " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", - " \"\"\"\n", - "\n", - " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", - " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", - " self.nn = mlp\n", - " self.initial_eps = eps\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " if train_eps:\n", - " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", - " else:\n", - " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", - "\n", - " def forward(\n", - " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", - " ) -> torch.Tensor:\n", - " \"\"\"\"\"\"\n", - " if isinstance(x, torch.Tensor):\n", - " x: OptPairTensor = (x, x)\n", - "\n", - " # Node and edge feature dimensionalites need to match.\n", - " if isinstance(edge_index, torch.Tensor):\n", - " assert edge_attr is not None\n", - " assert x[0].size(-1) == edge_attr.size(-1)\n", - " elif isinstance(edge_index, SparseTensor):\n", - " assert x[0].size(-1) == edge_index.size(-1)\n", - "\n", - " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", - " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", - "\n", - " x_r = x[1]\n", - " if x_r is not None:\n", - " out += (1 + self.eps) * x_r\n", - "\n", - " return self.nn(out)\n", - "\n", - " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", - " if self.activation:\n", - " return self.activation(x_j + edge_attr)\n", - " else:\n", - " return x_j + edge_attr\n", - "\n", - " def __repr__(self):\n", - " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", - "\n", - "\n", - "class GINEncoder(torch.nn.Module):\n", - " def 
__init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", - " super().__init__()\n", - "\n", - " self.hidden_dim = hidden_dim\n", - " self.num_convs = num_convs\n", - " self.short_cut = short_cut\n", - " self.concat_hidden = concat_hidden\n", - " self.node_emb = nn.Embedding(100, hidden_dim)\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " self.convs = nn.ModuleList()\n", - " for i in range(self.num_convs):\n", - " self.convs.append(\n", - " GINEConv(\n", - " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", - " activation=activation,\n", - " )\n", - " )\n", - "\n", - " def forward(self, z, edge_index, edge_attr):\n", - " \"\"\"\n", - " Input:\n", - " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", - " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", - " Output:\n", - " node_feature: graph feature\n", - " \"\"\"\n", - "\n", - " node_attr = self.node_emb(z) # (num_node, hidden)\n", - "\n", - " hiddens = []\n", - " conv_input = node_attr # (num_node, hidden)\n", - "\n", - " for conv_idx, conv in enumerate(self.convs):\n", - " hidden = conv(conv_input, edge_index, edge_attr)\n", - " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", - " hidden = self.activation(hidden)\n", - " assert hidden.shape == conv_input.shape\n", - " if self.short_cut and hidden.shape == conv_input.shape:\n", - " hidden += conv_input\n", - "\n", - " hiddens.append(hidden)\n", - " conv_input = hidden\n", - "\n", - " if self.concat_hidden:\n", - " node_feature = torch.cat(hiddens, dim=-1)\n", - " else:\n", - " node_feature = hiddens[-1]\n", - "\n", - " return node_feature\n", - "\n", - "\n", - "class MLPEdgeEncoder(Module):\n", - " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", - " super().__init__()\n", - " self.hidden_dim = hidden_dim\n", - " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", - " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", - "\n", - " @property\n", - " def out_channels(self):\n", - " return self.hidden_dim\n", - "\n", - " def forward(self, edge_length, edge_type):\n", - " \"\"\"\n", - " Input:\n", - " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", - " Returns:\n", - " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", - " \"\"\"\n", - " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", - " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", - " return d_emb * edge_attr # (num_edge, hidden)\n", - "\n", - "\n", - "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", - " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", - " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", - " return h_pair\n", - "\n", - "\n", - "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", - " \"\"\"\n", - " Args:\n", - " num_nodes: Number of atoms.\n", - " edge_index: Bond indices of the original graph.\n", - " edge_type: Bond types of the original graph.\n", - " order: Extension order.\n", - " Returns:\n", - " new_edge_index: Extended edge indices. 
new_edge_type: Extended edge types.\n", - " \"\"\"\n", - "\n", - " def binarize(x):\n", - " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", - "\n", - " def get_higher_order_adj_matrix(adj, order):\n", - " \"\"\"\n", - " Args:\n", - " adj: (N, N)\n", - " type_mat: (N, N)\n", - " Returns:\n", - " Following attributes will be updated:\n", - " - edge_index\n", - " - edge_type\n", - " Following attributes will be added to the data object:\n", - " - bond_edge_index: Original edge_index.\n", - " \"\"\"\n", - " adj_mats = [\n", - " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", - " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", - " ]\n", - "\n", - " for i in range(2, order + 1):\n", - " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", - " order_mat = torch.zeros_like(adj)\n", - "\n", - " for i in range(1, order + 1):\n", - " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", - "\n", - " return order_mat\n", - "\n", - " num_types = 22\n", - " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", - " # from rdkit.Chem.rdchem import BondType as BT\n", - " N = num_nodes\n", - " adj = to_dense_adj(edge_index).squeeze(0)\n", - " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", - "\n", - " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", - " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", - " assert (type_mat * type_highorder == 0).all()\n", - " type_new = type_mat + type_highorder\n", - "\n", - " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", - " _, edge_order = dense_to_sparse(adj_order)\n", - "\n", - " # data.bond_edge_index = data.edge_index # Save original edges\n", - " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", - " assert edge_type.dim() == 1\n", - " N = pos.size(0)\n", - "\n", - " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", - "\n", - " if is_sidechain is None:\n", - " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", - " else:\n", - " # fetch sidechain and its batch index\n", - " is_sidechain = is_sidechain.bool()\n", - " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", - " sidechain_pos = pos[is_sidechain]\n", - " sidechain_index = dummy_index[is_sidechain]\n", - " sidechain_batch = batch[is_sidechain]\n", - "\n", - " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", - " r_edge_index_x = assign_index[1]\n", - " r_edge_index_y = assign_index[0]\n", - " r_edge_index_y = sidechain_index[r_edge_index_y]\n", - "\n", - " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", - " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", - " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", - " # delete self loop\n", - " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", - "\n", - " rgraph_adj = torch.sparse.LongTensor(\n", - " rgraph_edge_index,\n", - " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", - " 
torch.Size([N, N]),\n", - " )\n", - "\n", - " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", - "\n", - " new_edge_index = composed_adj.indices()\n", - " new_edge_type = composed_adj.values().long()\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def extend_graph_order_radius(\n", - " num_nodes,\n", - " pos,\n", - " edge_index,\n", - " edge_type,\n", - " batch,\n", - " order=3,\n", - " cutoff=10.0,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - "):\n", - " if extend_order:\n", - " edge_index, edge_type = _extend_graph_order(\n", - " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", - " )\n", - "\n", - " if extend_radius:\n", - " edge_index, edge_type = _extend_to_radius_graph(\n", - " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", - " )\n", - "\n", - " return edge_index, edge_type\n", - "\n", - "\n", - "def get_distance(pos, edge_index):\n", - " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", - "\n", - "\n", - "def graph_field_network(score_d, pos, edge_index, edge_length):\n", - " \"\"\"\n", - " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", - " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", - " \"\"\"\n", - " N = pos.size(0)\n", - " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", - " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", - " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", - " ) # (N, 3)\n", - " return score_pos\n", - "\n", - "\n", - "def clip_norm(vec, limit, p=2):\n", - " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", - " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", - " return vec * denom\n", - "\n", - "\n", - "def is_local_edge(edge_type):\n", - " return edge_type > 0\n" + }, + "output_type": "display_data" + } + ], + "source": [ + "# new molecule\n", + "show(mols_gen[idx])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KJr4h2mwXeTo" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "01e0ba4e5da04914b4652b8d58565d7b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", + "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" ], - "metadata": { - "id": "oR1Y56QiLY90" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Main model class!" 
+ "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" + } + }, + "144ec959b7604a2cabb5ca46ae5e5379": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "19308ccac642498ab8b58462e3f1b0bb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1c6246f15b654f4daa11c9bcf997b78c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", + "placeholder": "​", + "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", + "value": "Downloading: 100%" + } + }, + "2489b5e5648541fbbdceadb05632a050": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "", + "description": "", + "disabled": false, + "icon": "compress", + "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", + "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", + "tooltip": "" + } + }, + "24d31fc3576e43dd9f8301d2ef3a37ab": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2918bfaadc8d4b1a9832522c40dfefb8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c9362906e4b40189f16d14aa9a348da": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "42f6c3db29d7484ba6b4f73590abd2f4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + 
"_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "4a081cdc2ec3421ca79dd933b7e2b0c4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "SliderStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "SliderStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "", + "handle_color": null + } + }, + "5146907ef6764654ad7d598baebc8b58": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "IntSliderModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "IntSliderModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "IntSliderView", + "continuous_update": true, + "description": "", + "description_tooltip": null, + "disabled": false, + "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", + "max": 0, + "min": 0, + "orientation": "horizontal", + "readout": true, + "readout_format": "d", + "step": 1, + "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", + "value": 0 + } + }, + "561f742d418d4721b0670cc8dd62e22c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6010fc8daa7a44d5aec4b830ec2ebaa1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", + "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "IPY_MODEL_6526646be5ed415c84d1245b040e629b" ], - "metadata": { - "id": "QWrHJFcYXyUB" - } - }, - { - "cell_type": "code", - "source": [ - "class MoleculeGNN(ModelMixin, ConfigMixin):\n", - " @register_to_config\n", - " def __init__(\n", - " self,\n", - " hidden_dim=128,\n", - " num_convs=6,\n", - " num_convs_local=4,\n", - " cutoff=10.0,\n", - " mlp_act=\"relu\",\n", - " edge_order=3,\n", - " edge_encoder=\"mlp\",\n", - " smooth_conv=True,\n", - " ):\n", - " super().__init__()\n", - " self.cutoff = cutoff\n", - " self.edge_encoder = edge_encoder\n", - " self.edge_order = edge_order\n", - "\n", - " \"\"\"\n", - " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", - " in SchNetEncoder\n", - " \"\"\"\n", - " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - "\n", - " \"\"\"\n", - " The graph neural network that 
extracts node-wise features.\n", - " \"\"\"\n", - " self.encoder_global = SchNetEncoder(\n", - " hidden_channels=hidden_dim,\n", - " num_filters=hidden_dim,\n", - " num_interactions=num_convs,\n", - " edge_channels=self.edge_encoder_global.out_channels,\n", - " cutoff=cutoff,\n", - " smooth=smooth_conv,\n", - " )\n", - " self.encoder_local = GINEncoder(\n", - " hidden_dim=hidden_dim,\n", - " num_convs=num_convs_local,\n", - " )\n", - "\n", - " \"\"\"\n", - " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", - " gradients w.r.t. edge_length (out_dim = 1).\n", - " \"\"\"\n", - " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " \"\"\"\n", - " Incorporate parameters together\n", - " \"\"\"\n", - " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", - " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", - "\n", - " def _forward(\n", - " self,\n", - " atom_type,\n", - " pos,\n", - " bond_index,\n", - " bond_type,\n", - " batch,\n", - " time_step, # NOTE, model trained without timestep performed best\n", - " edge_index=None,\n", - " edge_type=None,\n", - " edge_length=None,\n", - " return_edges=False,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - " ):\n", - " \"\"\"\n", - " Args:\n", - " atom_type: Types of atoms, (N, ).\n", - " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", - " bond_type: Bond types, (E, ).\n", - " batch: Node index to graph index, (N, ).\n", - " \"\"\"\n", - " N = atom_type.size(0)\n", - " if edge_index is None or edge_type is None or edge_length is None:\n", - " edge_index, edge_type = extend_graph_order_radius(\n", - " num_nodes=N,\n", - " pos=pos,\n", - " edge_index=bond_index,\n", - " edge_type=bond_type,\n", - " batch=batch,\n", - " order=self.edge_order,\n", - " cutoff=self.cutoff,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " is_sidechain=is_sidechain,\n", - " )\n", - " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", - " local_edge_mask = is_local_edge(edge_type) # (E, )\n", - "\n", - " # with the parameterization of NCSNv2\n", - " # DDPM loss implicit handle the noise variance scale conditioning\n", - " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", - "\n", - " # Encoding global\n", - " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - "\n", - " # Global\n", - " node_attr_global = self.encoder_global(\n", - " z=atom_type,\n", - " edge_index=edge_index,\n", - " edge_length=edge_length,\n", - " edge_attr=edge_attr_global,\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_global = assemble_atom_pair_feature(\n", - " node_attr=node_attr_global,\n", - " edge_index=edge_index,\n", - " edge_attr=edge_attr_global,\n", - " ) # (E_global, 2H)\n", - " # Invariant features of edges (radius graph, global)\n", - " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", - "\n", - " # Encoding local\n", - " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, 
edge_type=edge_type) # Embed edges\n", - " # edge_attr += temb_edge\n", - "\n", - " # Local\n", - " node_attr_local = self.encoder_local(\n", - " z=atom_type,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_local = assemble_atom_pair_feature(\n", - " node_attr=node_attr_local,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " ) # (E_local, 2H)\n", - "\n", - " # Invariant features of edges (bond graph, local)\n", - " if isinstance(sigma_edge, torch.Tensor):\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", - " 1.0 / sigma_edge[local_edge_mask]\n", - " ) # (E_local, 1)\n", - " else:\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", - "\n", - " if return_edges:\n", - " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", - " else:\n", - " return edge_inv_global, edge_inv_local\n", - "\n", - " def forward(\n", - " self,\n", - " sample,\n", - " timestep: Union[torch.Tensor, float, int],\n", - " return_dict: bool = True,\n", - " sigma=1.0,\n", - " global_start_sigma=0.5,\n", - " w_global=1.0,\n", - " extend_order=False,\n", - " extend_radius=True,\n", - " clip_local=None,\n", - " clip_global=1000.0,\n", - " ) -> Union[MoleculeGNNOutput, Tuple]:\n", - " r\"\"\"\n", - " Args:\n", - " sample: a packed torch geometric data object\n", - " timestep (`torch.Tensor` or `float` or `int`): (batch) timesteps. TODO: verify type and shape.\n", - " return_dict (`bool`, *optional*, defaults to `True`):\n", - " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", - " Returns:\n", - " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", - " `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample tensor.\n", - " \"\"\"\n", - "\n", - " # unpack sample\n", - " atom_type = sample.atom_type\n", - " bond_index = sample.edge_index\n", - " bond_type = sample.edge_type\n", - " num_graphs = sample.num_graphs\n", - " pos = sample.pos\n", - "\n", - " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", - "\n", - " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", - " atom_type=atom_type,\n", - " pos=sample.pos,\n", - " bond_index=bond_index,\n", - " bond_type=bond_type,\n", - " batch=sample.batch,\n", - " time_step=timesteps,\n", - " return_edges=True,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " ) # (E_global, 1), (E_local, 1)\n", - "\n", - " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", - " node_eq_local = graph_field_network(\n", - " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", - " )\n", - " if clip_local is not None:\n", - " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", - "\n", - " # Global\n", - " if sigma < global_start_sigma:\n", - " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", - " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", - " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", - " else:\n", - " node_eq_global = 0\n", - "\n", - " # Sum\n", - " eps_pos = node_eq_local + node_eq_global * w_global\n", - "\n", - " if not return_dict:\n", - " return (-eps_pos,)\n", - "\n", - " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" + "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" + } + }, + "65195cb7a4134f4887e9dd19f3676462": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + }, + "6526646be5ed415c84d1245b040e629b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", + "placeholder": "​", + "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", + "value": " 401/401 [00:00<00:00, 13.5kB/s]" + } + }, + "695ab5bbf30a4ab19df1f9f33469f314": { + "model_module": "nglview-js-widgets", + "model_module_version": "3.0.1", + "model_name": "ColormakerRegistryModel", + "state": { + "_dom_classes": [], + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "ColormakerRegistryModel", + "_msg_ar": [], + "_msg_q": [], + "_ready": false, + "_view_count": null, + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "ColormakerRegistryView", + "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" + } + }, + 
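The `graph_field_network` call in the `forward` above is the step that turns the invariant per-edge scores into equivariant per-node position updates (eqs. 5-7 of GeoDiff, as the comment notes). A minimal sketch of that chain rule, assuming `scatter_add` from `torch_scatter`; the helper name and exact reduction here are illustrative, not the notebook's verbatim implementation:

```python
import torch
from torch_scatter import scatter_add

def edge_scores_to_node_updates(edge_inv, pos, edge_index, edge_length):
    """Sketch: push each edge's invariant score along its unit bond vector."""
    row, col = edge_index                            # (E,), (E,)
    direction = (pos[row] - pos[col]) / edge_length  # (E, 3), unit vectors
    contrib = edge_inv * direction                   # (E, 3)
    # each edge pushes its two endpoints in opposite directions
    node_update = scatter_add(contrib, row, dim=0, dim_size=pos.size(0))
    node_update += scatter_add(-contrib, col, dim=0, dim_size=pos.size(0))
    return node_update  # (N, 3): rotation-equivariant, translation-invariant
```

Because every term is a position difference scaled by an invariant, rotating `pos` rotates the output identically, which is the property the paper's equivariance argument relies on.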
"7e0bb1b8d65249d3974200686b193be2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", + "placeholder": "​", + "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", + "value": "Downloading: 100%" + } + }, + "872915dd1bb84f538c44e26badabafdd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a4bfdca35cc54dae8812720f1b276a08": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a9e388f22a9742aaaf538e22575c9433": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + 
"top": null, + "visibility": null, + "width": null + } + }, + "abce2a80e6304df3899109c6d6cac199": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "34px" + } + }, + "b7feb522161f4cf4b7cc7c1a078ff12d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", + "placeholder": "​", + "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", + "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" + } + }, + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", + "max": 401, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", + "value": 401 + } + }, + "bbef741e76ec41b7ab7187b487a383df": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": 
null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be446195da2b4ff2aec21ec5ff963a54": { + "model_module": "nglview-js-widgets", + "model_module_version": "3.0.1", + "model_name": "NGLModel", + "state": { + "_camera_orientation": [ + -15.519693580202304, + -14.065056548036177, + -23.53197484807691, + 0, + -23.357853515109753, + 20.94055073042662, + 2.888695042134944, + 0, + 14.352363398292775, + 18.870825741878015, + -20.744689572909344, + 0, + 0.2724999189376831, + 0.6940000057220459, + -0.3734999895095825, + 1 ], - "metadata": { - "id": "MCeZA1qQXzoK" + "_camera_str": "orthographic", + "_dom_classes": [], + "_gui_theme": null, + "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", + "_igui": null, + "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "NGLModel", + "_ngl_color_dict": {}, + "_ngl_coordinate_resource": {}, + "_ngl_full_stage_parameters": { + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "backgroundColor": "white", + "cameraEyeSep": 0.3, + "cameraFov": 40, + "cameraType": "perspective", + "clipDist": 10, + "clipFar": 100, + "clipNear": 0, + "fogFar": 100, + "fogNear": 50, + "hoverTimeout": 0, + "impostor": true, + "lightColor": 14540253, + "lightIntensity": 1, + "mousePreset": "default", + "panSpeed": 1, + "quality": "medium", + "rotateSpeed": 2, + "sampleLevel": 0, + "tooltip": true, + "workerDefault": true, + "zoomSpeed": 1.2 }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CCIrPYSJj9wd" - }, - "source": [ - "### Load pretrained model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YdrAr6Ch--Ab" - }, - "source": [ - "#### Load a model\n", - "The model used here is built around an\n", - "equivariant convolutional layer, named graph field network (GFN).\n", - "\n", - "The warning about `betas` and `alphas` can be ignored; those were moved to the scheduler."
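The next cell (its download-progress widgets are what the surrounding metadata describes) instantiates the pretrained GFN. As a hedged sketch only — the checkpoint id and scheduler settings below are placeholders, not taken from this diff — loading goes through the `ModelMixin.from_pretrained` machinery that `MoleculeGNN` inherits:

```python
import torch
from diffusers import DDPMScheduler

# placeholder repo id; the real one is in the notebook cell, not shown in this hunk
model = MoleculeGNN.from_pretrained("<org>/<geodiff-checkpoint>").to("cuda")

# the `betas`/`alphas` that trigger the warning above now live on the scheduler
scheduler = DDPMScheduler(num_train_timesteps=1000)  # assumed timestep count
```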
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DyCo0nsqjbml", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 172, - "referenced_widgets": [ - "d90f304e9560472eacfbdd11e46765eb", - "1c6246f15b654f4daa11c9bcf997b78c", - "c2321b3bff6f490ca12040a20308f555", - "b7feb522161f4cf4b7cc7c1a078ff12d", - "e2d368556e494ae7ae4e2e992af2cd4f", - "bbef741e76ec41b7ab7187b487a383df", - "561f742d418d4721b0670cc8dd62e22c", - "872915dd1bb84f538c44e26badabafdd", - "d022575f1fa2446d891650897f187b4d", - "fdc393f3468c432aa0ada05e238a5436", - "2c9362906e4b40189f16d14aa9a348da", - "6010fc8daa7a44d5aec4b830ec2ebaa1", - "7e0bb1b8d65249d3974200686b193be2", - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "6526646be5ed415c84d1245b040e629b", - "24d31fc3576e43dd9f8301d2ef3a37ab", - "2918bfaadc8d4b1a9832522c40dfefb8", - "a4bfdca35cc54dae8812720f1b276a08", - "e4901541199b45c6a18824627692fc39", - "f915cf874246446595206221e900b2fe", - "a9e388f22a9742aaaf538e22575c9433", - "42f6c3db29d7484ba6b4f73590abd2f4" - ] + "_ngl_msg_archive": [ + { + "args": [ + { + "binary": false, + "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 
H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", + "type": "blob" + } + ], + "kwargs": { + "defaultRepresentation": true, + "ext": "pdb" }, - "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" + "methodName": "loadFile", + "reconstruc_color_scheme": false, + "target": "Stage", + "type": "call_method" + } + ], + "_ngl_original_stage_parameters": { + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "backgroundColor": "white", + "cameraEyeSep": 0.3, + "cameraFov": 40, + "cameraType": "perspective", + "clipDist": 10, + "clipFar": 100, + "clipNear": 0, + "fogFar": 100, + "fogNear": 50, + "hoverTimeout": 0, + "impostor": true, + "lightColor": 14540253, + "lightIntensity": 1, + "mousePreset": "default", + "panSpeed": 1, + "quality": "medium", + "rotateSpeed": 2, + "sampleLevel": 0, + "tooltip": true, + "workerDefault": true, + "zoomSpeed": 1.2 }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", - "\n", - "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", - "\n" - ] + "metalness": 0, + "multipleBond": "off", + "opacity": 1, + "openEnded": true, + "quality": "high", + "radialSegments": 20, + "radiusData": {}, + "radiusScale": 2, + "radiusSize": 0.15, + "radiusType": "size", + "roughness": 0.4, + "sele": "", + "side": "double", + "sphereDetail": 2, + "useInteriorColor": true, + "visible": true, + "wireframe": false + }, + "type": "ball+stick" } - ], - "source": [ - "import torch\n", - "import numpy as np\n", - "\n", - "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", - "dataset = torch.load('/content/molecules.pkl')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QZcmy1EvKQRk" - }, - "source": [ - "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." 
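Each entry is a `torch_geometric.data.Data` object. A hedged reading of its fields against `MoleculeGNN` and the sampling loop that follows (only the shapes come from the printed repr):

```python
entry = dataset[0]
# fields the sampling loop hands to the model:
#   entry.atom_type   (N,)     atom types; N = 51 atoms for this entry
#   entry.pos         (N, 3)   coordinates, replaced by Gaussian noise at sampling time
#   entry.edge_index  (2, E)   connectivity passed as bond_index; E = 598 here
#   entry.edge_type   (E,)     edge types matching edge_index
#   entry.pos_ref     (255, 3) = 5 reference conformations x 51 atoms
# the loop later draws pos_ref.size(0) // num_nodes = 5 samples for this molecule
```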
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JVjz6iH_H6Eh", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" - ] + }, + "1": { + "0": { + "params": { + "aspectRatio": 1.5, + "assembly": "default", + "bondScale": 0.3, + "bondSpacing": 0.75, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 }, - "metadata": {}, - "execution_count": 20 - } - ], - "source": [ - "dataset[0]" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Run the diffusion process" - ], - "metadata": { - "id": "vHNiZAUxNgoy" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jZ1KZrxKqENg" - }, - "source": [ - "#### Helper Functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "s240tYueqKKf" - }, - "outputs": [], - "source": [ - "from torch_geometric.data import Data, Batch\n", - "from torch_scatter import scatter_add, scatter_mean\n", - "from tqdm import tqdm\n", - "import copy\n", - "import os\n", - "import pickle\n", - "\n", - "def repeat_data(data: Data, num_repeat) -> Batch:\n", - " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", - " return Batch.from_data_list(datas)\n", - "\n", - "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", - " datas = batch.to_data_list()\n", - " new_data = []\n", - " for i in range(num_repeat):\n", - " new_data += copy.deepcopy(datas)\n", - " return Batch.from_data_list(new_data)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AMnQTk0eqT7Z" - }, - "source": [ - "#### Constants" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WYGkzqgzrHmF" - }, - "outputs": [], - "source": [ - "num_samples = 1 # solutions per molecule\n", - "num_molecules = 3\n", - "\n", - "DEVICE = 'cuda'\n", - "sampling_type = 'ddpm_noisy' # paper also uses \"generalize\" and \"ld\"\n", - "# constants for inference\n", - "w_global = 0.5 # 0.3 for qm9\n", - "global_start_sigma = 0.5\n", - "eta = 1.0\n", - "clip_local = None\n", - "clip_pos = None\n", - "\n", - "# constants for data handling\n", - "save_traj = False\n", - "save_data = False\n", - "output_dir = '/content/'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-xD5bJ3SqM7t" - }, - "source": [ - "#### Generate samples!\n", - "Note that the 3d representation of a molecule is referred to as the **conformation**." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "x9xuLUNg26z1", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " after removing the cwd from sys.path.\n", - "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" - ] - } - ], - "source": [ - "results = []\n", - "\n", - "# define sigmas\n", - "sigmas = torch.tensor(1.0 - 
scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", - "sigmas = sigmas.to(DEVICE)\n", - "\n", - "for count, data in enumerate(tqdm(dataset)):\n", - " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", - "\n", - " data_input = data.clone()\n", - " data_input['pos_ref'] = None\n", - " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", - "\n", - " # initial configuration\n", - " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", - "\n", - " # for logging animation of denoising\n", - " pos_traj = []\n", - " with torch.no_grad():\n", - "\n", - " # scale initial sample\n", - " pos = pos_init * sigmas[-1]\n", - " for t in scheduler.timesteps:\n", - " batch.pos = pos\n", - "\n", - " # generate geometry with model, then filter it\n", - " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", - "\n", - " # Update\n", - " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", - "\n", - " pos = reconstructed_pos\n", - "\n", - " if torch.isnan(pos).any():\n", - " print(\"NaN detected. Please restart.\")\n", - " raise FloatingPointError()\n", - "\n", - " # recenter graph of positions for next iteration\n", - " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", - "\n", - " # optional clipping\n", - " if clip_pos is not None:\n", - " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", - " pos_traj.append(pos.clone().cpu())\n", - "\n", - " pos_gen = pos.cpu()\n", - " if save_traj:\n", - " # pos_traj is a list of CPU tensors, so stack it directly\n", - " data.pos_gen = torch.stack(pos_traj)\n", - " else:\n", - " data.pos_gen = pos_gen\n", - " results.append(data)\n", - "\n", - "\n", - "if save_data:\n", - " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", - "\n", - " with open(save_path, 'wb') as f:\n", - " pickle.dump(results, f)" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Render the results!" - ], - "metadata": { - "id": "fSApwSaZNndW" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d47Zxo2OKdgZ" - }, - "source": [ - "This function allows us to render 3d in colab." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "e9Cd0kCAv9b8" - }, - "outputs": [], - "source": [ - "from google.colab import output\n", - "output.enable_custom_widget_manager()" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Helper functions" - ], - "metadata": { - "id": "RjaVuR15NqzF" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "28rBYa9NKhlz" - }, - "source": [ - "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer."
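A note on the `sigmas` defined at the top of the sampling cell above: it is the NCSN-style noise scale sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), built from the scheduler's cumulative product of alphas. An equivalent formulation, which also avoids the `UserWarning` about copy-constructing a tensor recorded in that cell's stderr:

```python
# same schedule as the cell above, without the torch.tensor(tensor) copy-construct
alphas_cumprod = scheduler.alphas_cumprod                   # (num_train_timesteps,)
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod).sqrt().to(DEVICE)
# sigmas[-1] is the largest noise scale; it scales the initial random positions
```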
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LKdKdwxcyTQ6" - }, - "outputs": [], - "source": [ - "from copy import deepcopy\n", - "def set_rdmol_positions(rdkit_mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " mol = deepcopy(rdkit_mol)\n", - " set_rdmol_positions_(mol, pos)\n", - " return mol\n", - "\n", - "def set_rdmol_positions_(mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " mol: An `rdkit.Chem.rdchem.Mol` object, modified in place.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " for i in range(pos.shape[0]):\n", - " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", - " return mol\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NuE10hcpKmzK" - }, - "source": [ - "Process the generated data to make it easy to view." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KieVE1vc0_Vs", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "collected 5 generated molecules in `mols_gen`\n" - ] - } - ], - "source": [ - "# the model can generate multiple conformations per 2d geometry\n", - "num_gen = results[0]['pos_gen'].shape[0]\n", - "\n", - "# init storage objects\n", - "mols_gen = []\n", - "mols_orig = []\n", - "for to_process in results:\n", - "\n", - " # store the reference 3d position\n", - " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # store the generated 3d position\n", - " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # copy data to new object\n", - " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", - "\n", - " # append results\n", - " mols_gen.append(new_mol)\n", - " mols_orig.append(to_process.rdmol)\n", - "\n", - "print(f\"collected {len(mols_gen)} generated molecules in `mols_gen`\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tin89JwMKp4v" - }, - "source": [ - "Import tools to visualize the 2d chemical diagram of the molecule." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yqV6gllSZn38" - }, - "outputs": [], - "source": [ - "from rdkit.Chem import AllChem\n", - "from rdkit import Chem\n", - "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n", - "from IPython.display import SVG, display" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TFNKmGddVoOk" - }, - "source": [ - "Select molecule to visualize" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KzuwLlrrVaGc" - }, - "outputs": [], - "source": [ - "idx = 0\n", - "assert idx < len(results), \"selected a molecule that was not generated\"" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Viewing" - ], - "metadata": { - "id": "hkb8w0_SNtU8" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I3R4QBQeKttN" - }, - "source": [ - "This 2D rendering is the equivalent of the **input to the model**!"
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gkQRWjraaKex", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 321 - }, - "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "<IPython.core.display.SVG object>" - ], - "image/svg+xml": "[SVG markup lost in extraction: 2d chemical diagram of the molecule drawn by MolDraw2DSVG]" + "clipNear": 0, + "clipRadius": 0, + "colorMode": "hcl", + "colorReverse": false, + "colorScale": "", + "colorScheme": "element", + "colorValue": 9474192, + "cylinderOnly": false, + "defaultAssembly": "", + "depthWrite": true, + "diffuse": 16777215, + "diffuseInterior": false, + "disableImpostor": false, + "disablePicking": false, + "flatShaded": false, + "interiorColor": 2236962, + "interiorDarkening": 0, + "lazy": false, + "lineOnly": false, + "linewidth": 2, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] }, - "metadata": {} + "metalness": 0, + "multipleBond": "off", + "opacity": 1, + "openEnded": true, + "quality": "high", + "radialSegments": 20, + "radiusData": {}, + "radiusScale": 2, + "radiusSize": 0.15, + "radiusType": "size", + "roughness": 0.4, + "sele": "", + "side": "double", + "sphereDetail": 2, + "useInteriorColor": true, + "visible": true, + "wireframe": false + }, + "type": "ball+stick" } } - ], - "source": [ - "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", - "molSize=(450,300)\n", - "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", - "drawer.DrawMolecule(mc)\n", - "drawer.FinishDrawing()\n", - "svg = drawer.GetDrawingText()\n", - "display(SVG(svg.replace('svg:','')))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z4FDMYMxKw2I" }, - "source": [ - "Generate the 3d molecule!"
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aT1Bkb8YxJfV", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 17, - "referenced_widgets": [ - "695ab5bbf30a4ab19df1f9f33469f314", - "eac6a8dcdc9d4335a2e51031793ead29" - ] - }, - "outputId": "b98870ae-049d-4386-b676-166e9526bda2" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "695ab5bbf30a4ab19df1f9f33469f314" - } - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" - } - } - } - } - } + "_ngl_serialize": false, + "_ngl_version": "", + "_ngl_view_id": [ + "FB989FD1-5B9C-446B-8914-6B58AF85446D" ], - "source": [ - "from nglview import show_rdkit as show" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pxtq8I-I18C-", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 337, - "referenced_widgets": [ - "be446195da2b4ff2aec21ec5ff963a54", - "c6596896148b4a8a9c57963b67c7782f", - "2489b5e5648541fbbdceadb05632a050", - "01e0ba4e5da04914b4652b8d58565d7b", - "c30e6c2f3e2a44dbbb3d63bd519acaa4", - "f31c6e40e9b2466a9064a2669933ecd5", - "19308ccac642498ab8b58462e3f1b0bb", - "4a081cdc2ec3421ca79dd933b7e2b0c4", - "e5c0d75eb5e1447abd560c8f2c6017e1", - "5146907ef6764654ad7d598baebc8b58", - "144ec959b7604a2cabb5ca46ae5e5379", - "abce2a80e6304df3899109c6d6cac199", - "65195cb7a4134f4887e9dd19f3676462" - ] - }, - "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "NGLWidget()" - ], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "be446195da2b4ff2aec21ec5ff963a54" - } - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" - } - } - } - } - } + "_player_dict": {}, + "_scene_position": {}, + "_scene_rotation": {}, + "_synced_model_ids": [], + "_synced_repr_model_ids": [], + "_view_count": null, + "_view_height": "", + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "NGLView", + "_view_width": "", + "background": "white", + "frame": 0, + "gui_style": null, + "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", + "max_frame": 0, + "n_components": 2, + "picked": {} + } + }, + "c2321b3bff6f490ca12040a20308f555": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", + "max": 3271865, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", + "value": 3271865 + } + }, + "c30e6c2f3e2a44dbbb3d63bd519acaa4": { + "model_module": "@jupyter-widgets/base", + 
"model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c6596896148b4a8a9c57963b67c7782f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d022575f1fa2446d891650897f187b4d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "d90f304e9560472eacfbdd11e46765eb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", + "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", + "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" ], - "source": [ - "# new molecule\n", - 
"show(mols_gen[idx])" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "KJr4h2mwXeTo" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "d90f304e9560472eacfbdd11e46765eb": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", - "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", - "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" - ], - "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" - } - }, - "1c6246f15b654f4daa11c9bcf997b78c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", - "placeholder": "​", - "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", - "value": "Downloading: 100%" - } - }, - "c2321b3bff6f490ca12040a20308f555": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", - "max": 3271865, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", - "value": 3271865 - } - }, - "b7feb522161f4cf4b7cc7c1a078ff12d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", - "placeholder": "​", - "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", - "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" - } - }, - "e2d368556e494ae7ae4e2e992af2cd4f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "bbef741e76ec41b7ab7187b487a383df": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "561f742d418d4721b0670cc8dd62e22c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "872915dd1bb84f538c44e26badabafdd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - 
"margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d022575f1fa2446d891650897f187b4d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "fdc393f3468c432aa0ada05e238a5436": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2c9362906e4b40189f16d14aa9a348da": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6010fc8daa7a44d5aec4b830ec2ebaa1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", - "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "IPY_MODEL_6526646be5ed415c84d1245b040e629b" - ], - "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" - } - }, - "7e0bb1b8d65249d3974200686b193be2": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", - "placeholder": "​", - "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", - "value": "Downloading: 100%" - } - }, - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", - "max": 401, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", - "value": 401 - } - }, - "6526646be5ed415c84d1245b040e629b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", - "placeholder": "​", - "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", - "value": " 401/401 [00:00<00:00, 13.5kB/s]" - } - }, - "24d31fc3576e43dd9f8301d2ef3a37ab": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2918bfaadc8d4b1a9832522c40dfefb8": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": 
null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a4bfdca35cc54dae8812720f1b276a08": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "e4901541199b45c6a18824627692fc39": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f915cf874246446595206221e900b2fe": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "a9e388f22a9742aaaf538e22575c9433": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - 
"grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "42f6c3db29d7484ba6b4f73590abd2f4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "695ab5bbf30a4ab19df1f9f33469f314": { - "model_module": "nglview-js-widgets", - "model_name": "ColormakerRegistryModel", - "model_module_version": "3.0.1", - "state": { - "_dom_classes": [], - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "ColormakerRegistryModel", - "_msg_ar": [], - "_msg_q": [], - "_ready": false, - "_view_count": null, - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "ColormakerRegistryView", - "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" - } - }, - "eac6a8dcdc9d4335a2e51031793ead29": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "be446195da2b4ff2aec21ec5ff963a54": { - "model_module": "nglview-js-widgets", - "model_name": "NGLModel", - "model_module_version": "3.0.1", - "state": { - "_camera_orientation": [ - -15.519693580202304, - -14.065056548036177, - -23.53197484807691, - 0, - -23.357853515109753, - 20.94055073042662, - 2.888695042134944, - 0, - 14.352363398292777, - 18.870825741878015, - -20.744689572909344, - 0, - 0.2724999189376831, - 0.6940000057220459, - -0.3734999895095825, - 1 - ], - "_camera_str": "orthographic", - "_dom_classes": [], - "_gui_theme": null, - "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", - "_igui": null, - "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", - "_model_module": 
"nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "NGLModel", - "_ngl_color_dict": {}, - "_ngl_coordinate_resource": {}, - "_ngl_full_stage_parameters": { - "impostor": true, - "quality": "medium", - "workerDefault": true, - "sampleLevel": 0, - "backgroundColor": "white", - "rotateSpeed": 2, - "zoomSpeed": 1.2, - "panSpeed": 1, - "clipNear": 0, - "clipFar": 100, - "clipDist": 10, - "fogNear": 50, - "fogFar": 100, - "cameraFov": 40, - "cameraEyeSep": 0.3, - "cameraType": "perspective", - "lightColor": 14540253, - "lightIntensity": 1, - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "hoverTimeout": 0, - "tooltip": true, - "mousePreset": "default" - }, - "_ngl_msg_archive": [ - { - "target": "Stage", - "type": "call_method", - "methodName": "loadFile", - "reconstruc_color_scheme": false, - "args": [ - { - "type": "blob", - "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 
1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", - "binary": false - } - ], - "kwargs": { - "defaultRepresentation": true, - "ext": "pdb" - } - } - ], - "_ngl_original_stage_parameters": { - "impostor": true, - "quality": "medium", - "workerDefault": true, - "sampleLevel": 0, - "backgroundColor": "white", - "rotateSpeed": 2, - "zoomSpeed": 1.2, - "panSpeed": 1, - "clipNear": 0, - "clipFar": 100, - "clipDist": 10, - "fogNear": 50, - "fogFar": 100, - "cameraFov": 40, - "cameraEyeSep": 0.3, - "cameraType": "perspective", - "lightColor": 14540253, - "lightIntensity": 1, - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "hoverTimeout": 0, - "tooltip": true, - "mousePreset": "default" - }, - "_ngl_repr_dict": { - "0": { - "0": { - "type": "ball+stick", - "params": { - "lazy": false, - "visible": true, - "quality": "high", - "sphereDetail": 2, - "radialSegments": 20, - "openEnded": true, - "disableImpostor": false, - "aspectRatio": 1.5, - "lineOnly": false, - "cylinderOnly": false, - "multipleBond": "off", - "bondScale": 0.3, - "bondSpacing": 0.75, - "linewidth": 2, - "radiusType": "size", - "radiusData": {}, - "radiusSize": 0.15, - "radiusScale": 2, - "assembly": "default", - "defaultAssembly": "", - "clipNear": 0, - "clipRadius": 0, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 - }, - "flatShaded": false, - "opacity": 1, - "depthWrite": true, - "side": "double", - "wireframe": false, - "colorScheme": "element", - "colorScale": "", - "colorReverse": false, - "colorValue": 9474192, - "colorMode": "hcl", - "roughness": 0.4, - "metalness": 0, - "diffuse": 16777215, - "diffuseInterior": false, - "useInteriorColor": true, - "interiorColor": 2236962, - "interiorDarkening": 0, - "matrix": { - "elements": [ - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1 - ] - }, - "disablePicking": false, - "sele": "" - } - } - }, - "1": { - "0": { - "type": "ball+stick", - "params": { - "lazy": false, - "visible": true, - "quality": "high", - "sphereDetail": 2, - "radialSegments": 20, - "openEnded": true, - "disableImpostor": false, - "aspectRatio": 1.5, - "lineOnly": false, - "cylinderOnly": false, - "multipleBond": "off", - "bondScale": 0.3, - "bondSpacing": 0.75, - "linewidth": 2, - "radiusType": "size", - "radiusData": {}, - "radiusSize": 0.15, - "radiusScale": 2, - "assembly": "default", - "defaultAssembly": "", - "clipNear": 0, - "clipRadius": 0, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 - }, - "flatShaded": false, - "opacity": 1, - "depthWrite": true, - "side": "double", - "wireframe": false, - "colorScheme": "element", - "colorScale": "", - "colorReverse": false, - "colorValue": 9474192, - "colorMode": "hcl", - "roughness": 0.4, - "metalness": 0, - "diffuse": 16777215, - "diffuseInterior": false, - "useInteriorColor": true, - "interiorColor": 2236962, - "interiorDarkening": 0, - "matrix": { - "elements": [ - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1 - ] - }, - "disablePicking": false, - "sele": "" - } - } - } - }, - "_ngl_serialize": false, - "_ngl_version": "", - "_ngl_view_id": [ - "FB989FD1-5B9C-446B-8914-6B58AF85446D" - ], - 
"_player_dict": {}, - "_scene_position": {}, - "_scene_rotation": {}, - "_synced_model_ids": [], - "_synced_repr_model_ids": [], - "_view_count": null, - "_view_height": "", - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "NGLView", - "_view_width": "", - "background": "white", - "frame": 0, - "gui_style": null, - "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", - "max_frame": 0, - "n_components": 2, - "picked": {} - } - }, - "c6596896148b4a8a9c57963b67c7782f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2489b5e5648541fbbdceadb05632a050": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ButtonModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ButtonView", - "button_style": "", - "description": "", - "disabled": false, - "icon": "compress", - "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", - "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", - "tooltip": "" - } - }, - "01e0ba4e5da04914b4652b8d58565d7b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", - "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" - ], - "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" - } - }, - "c30e6c2f3e2a44dbbb3d63bd519acaa4": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - 
"flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f31c6e40e9b2466a9064a2669933ecd5": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "19308ccac642498ab8b58462e3f1b0bb": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "4a081cdc2ec3421ca79dd933b7e2b0c4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "SliderStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "", - "handle_color": null - } - }, - "e5c0d75eb5e1447abd560c8f2c6017e1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "PlayModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "PlayModel", - "_playing": false, - "_repeat": false, - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "PlayView", - "description": "", - "description_tooltip": null, - "disabled": false, - "interval": 100, - "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", - "max": 0, - "min": 0, - 
"show_repeat": true, - "step": 1, - "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", - "value": 0 - } - }, - "5146907ef6764654ad7d598baebc8b58": { - "model_module": "@jupyter-widgets/controls", - "model_name": "IntSliderModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "IntSliderModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "IntSliderView", - "continuous_update": true, - "description": "", - "description_tooltip": null, - "disabled": false, - "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", - "max": 0, - "min": 0, - "orientation": "horizontal", - "readout": true, - "readout_format": "d", - "step": 1, - "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", - "value": 0 - } - }, - "144ec959b7604a2cabb5ca46ae5e5379": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "abce2a80e6304df3899109c6d6cac199": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "34px" - } - }, - "65195cb7a4134f4887e9dd19f3676462": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ButtonStyleModel", - "model_module_version": 
"1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "button_color": null, - "font_weight": "" - } - } - } + "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" + } + }, + "e2d368556e494ae7ae4e2e992af2cd4f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e4901541199b45c6a18824627692fc39": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e5c0d75eb5e1447abd560c8f2c6017e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "PlayModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "PlayModel", + "_playing": false, + "_repeat": false, + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "PlayView", + "description": "", + "description_tooltip": null, + "disabled": false, + "interval": 100, + "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", + "max": 0, + "min": 0, + "show_repeat": true, 
+ "step": 1, + "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", + "value": 0 + } + }, + "eac6a8dcdc9d4335a2e51031793ead29": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f31c6e40e9b2466a9064a2669933ecd5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f915cf874246446595206221e900b2fe": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "fdc393f3468c432aa0ada05e238a5436": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } } - }, - "nbformat": 4, - 
"nbformat_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 } \ No newline at end of file diff --git a/examples/research_projects/gligen/demo.ipynb b/examples/research_projects/gligen/demo.ipynb index 571f1a0323a2..4930253ff66e 100644 --- a/examples/research_projects/gligen/demo.ipynb +++ b/examples/research_projects/gligen/demo.ipynb @@ -26,8 +26,7 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "import torch\n", - "from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline" + "from diffusers import StableDiffusionGLIGENPipeline" ] }, { @@ -36,16 +35,17 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", + "from transformers import CLIPTextModel, CLIPTokenizer\n", + "\n", "import diffusers\n", "from diffusers import (\n", " AutoencoderKL,\n", " DDPMScheduler,\n", - " UNet2DConditionModel,\n", - " UniPCMultistepScheduler,\n", " EulerDiscreteScheduler,\n", + " UNet2DConditionModel,\n", ")\n", - "from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n", + "\n", + "\n", "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n", "\n", "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n", @@ -122,6 +122,7 @@ "\n", "import numpy as np\n", "\n", + "\n", "boxes = np.array([x[1] for x in gen_boxes])\n", "boxes = boxes / 512\n", "boxes[:, 2] = boxes[:, 0] + boxes[:, 2]\n", From 7f95ec11be86747b20f442128c02c141851ecd0a Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Tue, 15 Oct 2024 13:50:52 +0000 Subject: [PATCH 27/29] make style & fix --- src/diffusers/models/adapter.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index 543713ee174e..c253bcb79e43 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -30,7 +30,7 @@ class MultiAdapter(ModelMixin): MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to user-assigned weighting. - This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as + This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading or saving. Args: @@ -77,10 +77,10 @@ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = Non r""" Args: xs (`torch.Tensor`): - A tensor of shape (batch, channel, height, width) representing input images for multiple adapter models, + A tensor of shape (batch, channel, height, width) representing input images for multiple adapter models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to `num_adapter` * number of channel per image. - + adapter_weights (`List[float]`, *optional*, defaults to None): A list of floats representing the weights which will be multiplied by each adapter's output before summing them together. @@ -119,8 +119,8 @@ def save_pretrained( save_directory (`str` or `os.PathLike`): The directory where the model will be saved. If the directory does not exist, it will be created. is_main_process (`bool`, optional, defaults=True): - Indicates whether current process is the main process or not. - Useful for distributed training (e.g., TPUs) and need to call this function on all processes. + Indicates whether current process is the main process or not. 
+ Useful for distributed training (e.g., TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only for the main process to avoid race conditions. save_function (`Callable`): Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment variable. safe_serialization (`bool`, optional, defaults=True): If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. @@ -153,7 +153,7 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] the model, set it back to training mode using `model.train()`. Warnings: - *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. + *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded. Args: @@ -185,8 +185,8 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] If specified, load weights from a `variant` file (*e.g.* pytorch_model..bin). `variant` will be ignored when using `from_flax`. use_safetensors (`bool`, *optional*, defaults to `None`): - If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is installed. - If `True`, the model will be forcibly loaded from`safetensors` weights. + If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is installed. + If `True`, the model will be forcibly loaded from`safetensors` weights. If `False`, `safetensors` is not used. """ idx = 0 From 67a35b81bb68f07d3d674d1aff7ecd6b4ffd7e67 Mon Sep 17 00:00:00 2001 From: Ahnjj_DEV Date: Tue, 15 Oct 2024 15:00:24 +0000 Subject: [PATCH 28/29] make style & fix: 0.1.5 version ruff --- src/diffusers/models/adapter.py | 57 +++++++++++++++++---------------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py index c253bcb79e43..677a991f055e 100644 --- a/src/diffusers/models/adapter.py +++ b/src/diffusers/models/adapter.py @@ -30,8 +30,8 @@ class MultiAdapter(ModelMixin): MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to user-assigned weighting. - This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as - downloading or saving. + This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading + or saving. Args: adapters (`List[T2IAdapter]`, *optional*, defaults to None): @@ -77,14 +77,13 @@ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = Non r""" Args: xs (`torch.Tensor`): - A tensor of shape (batch, channel, height, width) representing input images for multiple adapter models, - concatenated along dimension 1(channel dimension). - The `channel` dimension should be equal to `num_adapter` * number of channel per image. + A tensor of shape (batch, channel, height, width) representing input images for multiple adapter + models, concatenated along dimension 1 (the channel dimension). The `channel` dimension should be equal to + `num_adapter` * the number of channels per image.
adapter_weights (`List[float]`, *optional*, defaults to None): - A list of floats representing the weights which will be multiplied by each adapter's output before summing - them together. - If `None`, equal weights will be used for all adapters. + A list of floats representing the weights which will be multiplied by each adapter's output before + summing them together. If `None`, equal weights will be used for all adapters. """ if adapter_weights is None: adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter) @@ -119,14 +118,15 @@ def save_pretrained( save_directory (`str` or `os.PathLike`): The directory where the model will be saved. If the directory does not exist, it will be created. is_main_process (`bool`, optional, defaults=True): - Indicates whether current process is the main process or not. - Useful for distributed training (e.g., TPUs) and need to call this function on all processes. - In this case, set `is_main_process=True` only for the main process to avoid race conditions. + Indicates whether the current process is the main process or not. Useful in distributed training (e.g., + on TPUs) when this function needs to be called on all processes. In this case, set `is_main_process=True` only + for the main process to avoid race conditions. save_function (`Callable`): - Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment variable. + Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace + `torch.save` with another method. Can also be configured using the `DIFFUSERS_SAVE_MODE` environment + variable. safe_serialization (`bool`, optional, defaults=True): - If `True`, save the model using `safetensors`. - If `False`, save the model with `pickle`. + If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. @@ -153,8 +153,9 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] the model, set it back to training mode using `model.train()`. Warnings: - *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. - *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded. + *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained + with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task. *Weights + from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded. Args: pretrained_model_path (`os.PathLike`): @@ -174,20 +175,20 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike] more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): - A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory available for each - GPU and the available CPU RAM if unset. + A dictionary mapping device identifiers to their maximum memory. Defaults to the maximum memory + available for each GPU and the available CPU RAM if unset.
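The chunk-and-weight behavior that the `forward` docstring above describes can be made concrete with a few lines of plain PyTorch. This is only a minimal sketch of the documented semantics: the two single-convolution modules stand in for real `T2IAdapter` instances and are hypothetical, as are the tensor sizes.

```python
# Minimal sketch of MultiAdapter's documented weighting semantics.
# The Conv2d "adapters" are hypothetical stand-ins, not real T2IAdapters.
import torch
import torch.nn as nn

adapters = [nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(3, 8, 3, padding=1)]
num_adapter = len(adapters)

# Two 3-channel control images concatenated along dim 1 (the channel
# dimension), so channel == num_adapter * 3, as the docstring requires.
xs = torch.randn(1, num_adapter * 3, 64, 64)

# Equal weights, mirroring the `adapter_weights=None` default.
adapter_weights = torch.tensor([1 / num_adapter] * num_adapter)

# Split the stacked control images back apart, run each through its adapter,
# and accumulate the weighted outputs.
accume_state = None
for x, w, adapter in zip(torch.chunk(xs, num_adapter, dim=1), adapter_weights, adapters):
    features = adapter(x)
    accume_state = w * features if accume_state is None else accume_state + w * features

print(accume_state.shape)  # torch.Size([1, 8, 64, 64])
```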
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, setting this argument to `True` will raise an error. variant (`str`, *optional*): - If specified, load weights from a `variant` file (*e.g.* pytorch_model..bin). `variant` will be - ignored when using `from_flax`. + If specified, load weights from a `variant` file (*e.g.* pytorch_model..bin). `variant` will + be ignored when using `from_flax`. use_safetensors (`bool`, *optional*, defaults to `None`): - If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is installed. - If `True`, the model will be forcibly loaded from`safetensors` weights. - If `False`, `safetensors` is not used. + If `None`, the `safetensors` weights will be downloaded if available **and** if the `safetensors` library is + installed. If `True`, the model will be forcibly loaded from `safetensors` weights. If `False`, + `safetensors` is not used. """ idx = 0 adapters = [] @@ -222,14 +223,16 @@ class T2IAdapter(ModelMixin, ConfigMixin): and [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235). - This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as downloading or saving. + This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as + downloading or saving. Args: in_channels (`int`, *optional*, defaults to `3`): - The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale image. + The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a grayscale + image. channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The number of channels in each downsample block's output hidden state. The `len(block_out_channels)` determines - the number of downsample blocks in the adapter. + The number of channels in each downsample block's output hidden state. The `len(block_out_channels)` + determines the number of downsample blocks in the adapter. num_res_blocks (`int`, *optional*, defaults to `2`): Number of ResNet blocks in each downsample block.
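Taken together, the `save_pretrained` and `from_pretrained` docstrings above describe a round trip through numbered sibling directories: the first adapter lands in the save directory itself and each further adapter in a `_1`-, `_2`-suffixed neighbor, which `from_pretrained` then walks back up. A hedged sketch of that round trip follows; the Hub checkpoint names are public T2I-Adapter weights and the local path is an arbitrary choice.

```python
# Sketch of the MultiAdapter save/load round trip described above.
# The checkpoint names and the local path are illustrative choices.
from diffusers import MultiAdapter, T2IAdapter

adapters = [
    T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd15v2"),
    T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd15v2"),
]
multi_adapter = MultiAdapter(adapters)

# Writes the first adapter to ./my_adapters and the second to ./my_adapters_1.
multi_adapter.save_pretrained("./my_adapters")

# Reloads every numbered sub-model and rebuilds the wrapper (in eval mode).
restored = MultiAdapter.from_pretrained("./my_adapters")
```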
downscale_factor (`int`, *optional*, defaults to `8`): From 944ad894b2574c54a9dc7f5253868de7785b4040 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Oct 2024 17:24:04 +0200 Subject: [PATCH 29/29] revert changes to examples --- .../geodiff_molecule_conformation.ipynb | 7230 ++++++++--------- examples/research_projects/gligen/demo.ipynb | 13 +- 2 files changed, 3617 insertions(+), 3626 deletions(-) diff --git a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb index 03f58f1f2f63..bde093802a5d 100644 --- a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb +++ b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb @@ -1,3660 +1,3652 @@ {
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F88mignPnalS"
      },
      "source": [
        "# Introduction\n",
        "\n",
        "This colab is designed to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n",
        "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n",
        "\n",
        "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph), what we want to generate is a stable 3d structure of the molecule.\n",
        "\n",
        "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n",
        "\n",
        "> Colab made by [natolambert](https://twitter.com/natolambert).\n",
        "\n",
        "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7cnwXMocnuzB"
      },
      "source": [
        "## Installations\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ff9SxWnaNId9"
      },
      "source": [
        "### Install Conda"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1g_6zOabItDk"
      },
      "source": [
        "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "K0ofXobG5Y-X",
        "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "nvcc: NVIDIA (R) Cuda compiler driver\n",
            "Copyright (c) 2005-2021 NVIDIA Corporation\n",
            "Built on Sun_Feb_14_21:12:58_PST_2021\n",
            "Cuda compilation tools, release 11.2, V11.2.152\n",
            "Build cuda_11.2.r11.2/compiler.29618528_0\n"
          ]
        }
      ],
      "source": [
        "!nvcc --version"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VfthW90vI0nw"
      },
      "source": [
        "Install Conda for some more complex dependencies for geometric networks."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2WNFzSnbiE0k",
        "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager.
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install -q condacolab" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NUsbWYCUI7Km" - }, - "source": [ - "Setup Conda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FZelreINdmd0", - "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "✨🍰✨ Everything looks OK!\n" - ] - } - ], - "source": [ - "import condacolab\n", - "\n", - "\n", - "condacolab.install()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JzDHaPU7I9Sn" - }, - "source": [ - "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMxRjHhL7w8V", - "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", - "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - cudatoolkit=11.1\n", - " - pytorch\n", - " - torchaudio\n", - " - torchvision\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 960 KB\n", - "\n", - "The following packages will be UPDATED:\n", - "\n", - " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", - "Preparing transaction: / \b\bdone\n", - "Verifying transaction: \\ \b\bdone\n", - "Executing transaction: / \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", - "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QDS6FPZ0Tu5b" - }, - "source": [ - "Need to remove a pathspec for colab that specifies the incorrect cuda version." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "dq1lxR10TtrR", - "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" - ] - } - ], - "source": [ - "!rm /usr/local/conda-meta/pinned" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z1L3DdZOJB30" - }, - "source": [ - "Install torch geometric (used in the model later)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "D5ukfCOWfjzK", - "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - pytorch-geometric=1.7.2\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " decorator-4.4.2 | py_0 11 KB conda-forge\n", - " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", - " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", - " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", - " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", - " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", - " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", - " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", - " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", - " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", - " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", - " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", - " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", - " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", - " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", - " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", - " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", - " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", - " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", - " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", - " 
------------------------------------------------------------\n", - " Total: 55.9 MB\n", - "\n", - "The following NEW packages will be INSTALLED:\n", - "\n", - " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", - " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", - " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", - " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", - " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", - " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", - " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", - " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", - " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", - " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", - " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", - " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", - " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", - " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", - " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", - " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", - " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", - " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", - " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", - "\n", - "The following packages will be DOWNGRADED:\n", - "\n", - " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", - "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", - "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", - "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", - "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", - "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", - "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", - "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", - "pytorch-cluster-1.5. 
| 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", - "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", - "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", - "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", - "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", - "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", - "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", - "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", - "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", - "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", - "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", - "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", - "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", - "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install -c rusty1s pytorch-geometric=1.7.2" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ppxv6Mdkalbc" - }, - "source": [ - "### Install Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mgQA_XN-XGY2", - "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/content\n", - "Cloning into 'diffusers'...\n", - "remote: Enumerating objects: 9298, done.\u001b[K\n", - "remote: Counting objects: 100% (40/40), done.\u001b[K\n", - "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", - "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", - "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", - "Resolving deltas: 100% (6168/6168), done.\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
 - "\u001b[0m"
 - ]
 - }
 - ],
 - "source": [
 - "%cd /content\n",
 - "\n",
 - "# install latest HF diffusers (will update to the release once added)\n",
 - "!git clone https://github.com/huggingface/diffusers.git\n",
 - "!pip install -q /content/diffusers\n",
 - "\n",
 - "# dependencies for diffusers\n",
 - "!pip install -q datasets transformers"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "LZO6AJKuJKO8"
 - },
 - "source": [
 - "Check that torch is installed correctly and is utilizing the GPU in the Colab"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": null,
 - "metadata": {
 - "colab": {
 - "base_uri": "https://localhost:8080/",
 - "height": 53
 - },
 - "id": "gZt7BNi1e1PA",
 - "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889"
 - },
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "True\n"
 - ]
 - },
 - {
 - "data": {
 - "application/vnd.google.colaboratory.intrinsic+json": {
 - "type": "string"
 - },
 - "text/plain": [
 - "'1.8.2'"
 - ]
 - },
 - "execution_count": 8,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "import torch\n",
 - "\n",
 - "\n",
 - "print(torch.cuda.is_available())\n",
 - "torch.__version__"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "KLE7CqlfJNUO"
 - },
 - "source": [
 - "### Install Chemistry-specific Dependencies\n",
 - "\n",
 - "Install RDKit, a tool for working with and visualizing chemistry in Python (you will use this to visualize the generated molecules later)."
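Once the install in the next cell finishes, a quick smoke test such as the following confirms that RDKit imports cleanly (a minimal sketch; the SMILES string is an arbitrary example, not part of the dataset):

```python
# Minimal RDKit smoke test: parse a SMILES string and count heavy atoms.
from rdkit import Chem

mol = Chem.MolFromSmiles("CCO")  # ethanol, an arbitrary example molecule
assert mol is not None and mol.GetNumAtoms() == 3  # two carbons + one oxygen
```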
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0CPv_NvehRz3", - "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting rdkit\n", - " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", - "Installing collected packages: rdkit\n", - "Successfully installed rdkit-2022.3.5\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install rdkit" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "88GaDbDPxJ5I" - }, - "source": [ - "### Get viewer from nglview\n", - "\n", - "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", - "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", - "The rdmol in this object is a source of ground truth for the generated molecules.\n", - "\n", - "You will use one rendering function from nglviewer later!\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "id": "jcl8GCS2mz6t", - "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting nglview\n", - " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", - "Collecting jupyterlab-widgets\n", - " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipywidgets>=7\n", - " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting widgetsnbextension~=4.0\n", - " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipython>=6.1.0\n", - " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipykernel>=4.5.1\n", - " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting traitlets>=4.3.1\n", - " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", - "Collecting pyzmq>=17\n", - " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting matplotlib-inline>=0.1\n", - " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", - "Collecting tornado>=6.1\n", - " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting nest-asyncio\n", - " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", - "Collecting debugpy>=1.0\n", - " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting psutil\n", - " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting 
jupyter-client>=6.1.12\n", - " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pickleshare\n", - " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", - "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", - "Collecting backcall\n", - " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", - "Collecting pexpect>4.3\n", - " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pygments\n", - " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting jedi>=0.16\n", - " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", - " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", - "Collecting parso<0.9.0,>=0.8.0\n", - " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", - "Collecting entrypoints\n", - " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", - "Collecting jupyter-core>=4.9.2\n", - " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ptyprocess>=0.5\n", - " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", - "Collecting wcwidth\n", - " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", - "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", - "Building wheels for collected packages: nglview\n", - " Building wheel for nglview (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", - " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", - "Successfully built nglview\n", - "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", - "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - }, - { - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "pexpect", - "pickleshare", - "wcwidth" - ] - } + }, + { + "cell_type": "markdown", + "source": [ + "### Install Conda" + ], + "metadata": { + "id": "ff9SxWnaNId9" } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "!pip install nglview" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8t8_e_uVLdKB" - }, - "source": [ - "## Create a diffusion model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "G0rMncVtNSqU" - }, - "source": [ - "### Model class(es)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L5FEXz5oXkzt" - }, - "source": [ - "Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-3-P4w5sXkRU" - }, - "outputs": [], - "source": [ - "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", - "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", - "from dataclasses import dataclass\n", - "from typing import Callable, Tuple, Union\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from torch import Tensor, nn\n", - "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", - "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", - "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", - "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", - "from torch_scatter import scatter_add\n", - "from torch_sparse import SparseTensor, coalesce\n", - "\n", - "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", - "from diffusers.modeling_utils import ModelMixin\n", - "from diffusers.utils import BaseOutput\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EzJQXPN_XrMX" - }, - "source": [ - "Helper classes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "oR1Y56QiLY90" - }, - "outputs": [], - "source": [ - "@dataclass\n", - "class 
MoleculeGNNOutput(BaseOutput):\n", - " \"\"\"\n", - " Args:\n", - " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", - " Hidden states output. Output of last layer of model.\n", - " \"\"\"\n", - "\n", - " sample: torch.Tensor\n", - "\n", - "\n", - "class MultiLayerPerceptron(nn.Module):\n", - " \"\"\"\n", - " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", - " Args:\n", - " input_dim (int): input dimension\n", - " hidden_dim (list of int): hidden dimensions\n", - " activation (str or function, optional): activation function\n", - " dropout (float, optional): dropout rate\n", - " \"\"\"\n", - "\n", - " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", - " super(MultiLayerPerceptron, self).__init__()\n", - "\n", - " self.dims = [input_dim] + hidden_dims\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", - " self.activation = None\n", - " if dropout > 0:\n", - " self.dropout = nn.Dropout(dropout)\n", - " else:\n", - " self.dropout = None\n", - "\n", - " self.layers = nn.ModuleList()\n", - " for i in range(len(self.dims) - 1):\n", - " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\"\"\"\n", - " for i, layer in enumerate(self.layers):\n", - " x = layer(x)\n", - " if i < len(self.layers) - 1:\n", - " if self.activation:\n", - " x = self.activation(x)\n", - " if self.dropout:\n", - " x = self.dropout(x)\n", - " return x\n", - "\n", - "\n", - "class ShiftedSoftplus(torch.nn.Module):\n", - " def __init__(self):\n", - " super(ShiftedSoftplus, self).__init__()\n", - " self.shift = torch.log(torch.tensor(2.0)).item()\n", - "\n", - " def forward(self, x):\n", - " return F.softplus(x) - self.shift\n", - "\n", - "\n", - "class CFConv(MessagePassing):\n", - " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", - " super(CFConv, self).__init__(aggr=\"add\")\n", - " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", - " self.lin2 = Linear(num_filters, out_channels)\n", - " self.nn = mlp\n", - " self.cutoff = cutoff\n", - " self.smooth = smooth\n", - "\n", - " self.reset_parameters()\n", - "\n", - " def reset_parameters(self):\n", - " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", - " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", - " self.lin2.bias.data.fill_(0)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " if self.smooth:\n", - " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", - " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", - " else:\n", - " C = (edge_length <= self.cutoff).float()\n", - " W = self.nn(edge_attr) * C.view(-1, 1)\n", - "\n", - " x = self.lin1(x)\n", - " x = self.propagate(edge_index, x=x, W=W)\n", - " x = self.lin2(x)\n", - " return x\n", - "\n", - " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", - " return x_j * W\n", - "\n", - "\n", - "class InteractionBlock(torch.nn.Module):\n", - " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", - " super(InteractionBlock, self).__init__()\n", - " mlp = Sequential(\n", - " Linear(num_gaussians, num_filters),\n", - " ShiftedSoftplus(),\n", - " Linear(num_filters, num_filters),\n", - " )\n", - " self.conv = CFConv(hidden_channels, hidden_channels, 
num_filters, mlp, cutoff, smooth)\n", - " self.act = ShiftedSoftplus()\n", - " self.lin = Linear(hidden_channels, hidden_channels)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " x = self.conv(x, edge_index, edge_length, edge_attr)\n", - " x = self.act(x)\n", - " x = self.lin(x)\n", - " return x\n", - "\n", - "\n", - "class SchNetEncoder(Module):\n", - " def __init__(\n", - " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", - " ):\n", - " super().__init__()\n", - "\n", - " self.hidden_channels = hidden_channels\n", - " self.num_filters = num_filters\n", - " self.num_interactions = num_interactions\n", - " self.cutoff = cutoff\n", - "\n", - " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", - "\n", - " self.interactions = ModuleList()\n", - " for _ in range(num_interactions):\n", - " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", - " self.interactions.append(block)\n", - "\n", - " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", - " if embed_node:\n", - " assert z.dim() == 1 and z.dtype == torch.long\n", - " h = self.embedding(z)\n", - " else:\n", - " h = z\n", - " for interaction in self.interactions:\n", - " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", - "\n", - " return h\n", - "\n", - "\n", - "class GINEConv(MessagePassing):\n", - " \"\"\"\n", - " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", - " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", - " \"\"\"\n", - "\n", - " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", - " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", - " self.nn = mlp\n", - " self.initial_eps = eps\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " if train_eps:\n", - " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", - " else:\n", - " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", - "\n", - " def forward(\n", - " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", - " ) -> torch.Tensor:\n", - " \"\"\"\"\"\"\n", - " if isinstance(x, torch.Tensor):\n", - " x: OptPairTensor = (x, x)\n", - "\n", - " # Node and edge feature dimensionalites need to match.\n", - " if isinstance(edge_index, torch.Tensor):\n", - " assert edge_attr is not None\n", - " assert x[0].size(-1) == edge_attr.size(-1)\n", - " elif isinstance(edge_index, SparseTensor):\n", - " assert x[0].size(-1) == edge_index.size(-1)\n", - "\n", - " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", - " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", - "\n", - " x_r = x[1]\n", - " if x_r is not None:\n", - " out += (1 + self.eps) * x_r\n", - "\n", - " return self.nn(out)\n", - "\n", - " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", - " if self.activation:\n", - " return self.activation(x_j + edge_attr)\n", - " else:\n", - " return x_j + edge_attr\n", - "\n", - " def __repr__(self):\n", - " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", - "\n", - "\n", - "class GINEncoder(torch.nn.Module):\n", - " def __init__(self, hidden_dim, num_convs=3, 
activation=\"relu\", short_cut=True, concat_hidden=False):\n", - " super().__init__()\n", - "\n", - " self.hidden_dim = hidden_dim\n", - " self.num_convs = num_convs\n", - " self.short_cut = short_cut\n", - " self.concat_hidden = concat_hidden\n", - " self.node_emb = nn.Embedding(100, hidden_dim)\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " self.convs = nn.ModuleList()\n", - " for i in range(self.num_convs):\n", - " self.convs.append(\n", - " GINEConv(\n", - " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", - " activation=activation,\n", - " )\n", - " )\n", - "\n", - " def forward(self, z, edge_index, edge_attr):\n", - " \"\"\"\n", - " Input:\n", - " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", - " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", - " Output:\n", - " node_feature: graph feature\n", - " \"\"\"\n", - "\n", - " node_attr = self.node_emb(z) # (num_node, hidden)\n", - "\n", - " hiddens = []\n", - " conv_input = node_attr # (num_node, hidden)\n", - "\n", - " for conv_idx, conv in enumerate(self.convs):\n", - " hidden = conv(conv_input, edge_index, edge_attr)\n", - " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", - " hidden = self.activation(hidden)\n", - " assert hidden.shape == conv_input.shape\n", - " if self.short_cut and hidden.shape == conv_input.shape:\n", - " hidden += conv_input\n", - "\n", - " hiddens.append(hidden)\n", - " conv_input = hidden\n", - "\n", - " if self.concat_hidden:\n", - " node_feature = torch.cat(hiddens, dim=-1)\n", - " else:\n", - " node_feature = hiddens[-1]\n", - "\n", - " return node_feature\n", - "\n", - "\n", - "class MLPEdgeEncoder(Module):\n", - " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", - " super().__init__()\n", - " self.hidden_dim = hidden_dim\n", - " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", - " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", - "\n", - " @property\n", - " def out_channels(self):\n", - " return self.hidden_dim\n", - "\n", - " def forward(self, edge_length, edge_type):\n", - " \"\"\"\n", - " Input:\n", - " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", - " Returns:\n", - " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", - " \"\"\"\n", - " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", - " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", - " return d_emb * edge_attr # (num_edge, hidden)\n", - "\n", - "\n", - "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", - " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", - " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", - " return h_pair\n", - "\n", - "\n", - "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", - " \"\"\"\n", - " Args:\n", - " num_nodes: Number of atoms.\n", - " edge_index: Bond indices of the original graph.\n", - " edge_type: Bond types of the original graph.\n", - " order: Extension order.\n", - " Returns:\n", - " new_edge_index: Extended edge indices. 
new_edge_type: Extended edge types.\n", - " \"\"\"\n", - "\n", - " def binarize(x):\n", - " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", - "\n", - " def get_higher_order_adj_matrix(adj, order):\n", - " \"\"\"\n", - " Args:\n", - " adj: (N, N)\n", - " type_mat: (N, N)\n", - " Returns:\n", - " Following attributes will be updated:\n", - " - edge_index\n", - " - edge_type\n", - " Following attributes will be added to the data object:\n", - " - bond_edge_index: Original edge_index.\n", - " \"\"\"\n", - " adj_mats = [\n", - " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", - " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", - " ]\n", - "\n", - " for i in range(2, order + 1):\n", - " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", - " order_mat = torch.zeros_like(adj)\n", - "\n", - " for i in range(1, order + 1):\n", - " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", - "\n", - " return order_mat\n", - "\n", - " num_types = 22\n", - " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", - " # from rdkit.Chem.rdchem import BondType as BT\n", - " N = num_nodes\n", - " adj = to_dense_adj(edge_index).squeeze(0)\n", - " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", - "\n", - " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", - " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", - " assert (type_mat * type_highorder == 0).all()\n", - " type_new = type_mat + type_highorder\n", - "\n", - " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", - " _, edge_order = dense_to_sparse(adj_order)\n", - "\n", - " # data.bond_edge_index = data.edge_index # Save original edges\n", - " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", - " assert edge_type.dim() == 1\n", - " N = pos.size(0)\n", - "\n", - " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", - "\n", - " if is_sidechain is None:\n", - " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", - " else:\n", - " # fetch sidechain and its batch index\n", - " is_sidechain = is_sidechain.bool()\n", - " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", - " sidechain_pos = pos[is_sidechain]\n", - " sidechain_index = dummy_index[is_sidechain]\n", - " sidechain_batch = batch[is_sidechain]\n", - "\n", - " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", - " r_edge_index_x = assign_index[1]\n", - " r_edge_index_y = assign_index[0]\n", - " r_edge_index_y = sidechain_index[r_edge_index_y]\n", - "\n", - " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", - " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", - " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", - " # delete self loop\n", - " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", - "\n", - " rgraph_adj = torch.sparse.LongTensor(\n", - " rgraph_edge_index,\n", - " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", - " 
torch.Size([N, N]),\n",
 - "    )\n",
 - "\n",
 - "    composed_adj = (bgraph_adj + rgraph_adj).coalesce()  # Sparse (N, N, T)\n",
 - "\n",
 - "    new_edge_index = composed_adj.indices()\n",
 - "    new_edge_type = composed_adj.values().long()\n",
 - "\n",
 - "    return new_edge_index, new_edge_type\n",
 - "\n",
 - "\n",
 - "def extend_graph_order_radius(\n",
 - "    num_nodes,\n",
 - "    pos,\n",
 - "    edge_index,\n",
 - "    edge_type,\n",
 - "    batch,\n",
 - "    order=3,\n",
 - "    cutoff=10.0,\n",
 - "    extend_order=True,\n",
 - "    extend_radius=True,\n",
 - "    is_sidechain=None,\n",
 - "):\n",
 - "    if extend_order:\n",
 - "        edge_index, edge_type = _extend_graph_order(\n",
 - "            num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n",
 - "        )\n",
 - "\n",
 - "    if extend_radius:\n",
 - "        edge_index, edge_type = _extend_to_radius_graph(\n",
 - "            pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n",
 - "        )\n",
 - "\n",
 - "    return edge_index, edge_type\n",
 - "\n",
 - "\n",
 - "def get_distance(pos, edge_index):\n",
 - "    return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n",
 - "\n",
 - "\n",
 - "def graph_field_network(score_d, pos, edge_index, edge_length):\n",
 - "    \"\"\"\n",
 - "    Transformation to make the epsilon predicted from the diffusion model roto-translationally equivariant. See\n",
 - "    equations 5-7 of the GeoDiff paper https://arxiv.org/pdf/2203.02923.pdf\n",
 - "    \"\"\"\n",
 - "    N = pos.size(0)\n",
 - "    dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]])  # (E, 3)\n",
 - "    score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n",
 - "        -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n",
 - "    )  # (N, 3)\n",
 - "    return score_pos\n",
 - "\n",
 - "\n",
 - "def clip_norm(vec, limit, p=2):\n",
 - "    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)\n",
 - "    denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n",
 - "    return vec * denom\n",
 - "\n",
 - "\n",
 - "def is_local_edge(edge_type):\n",
 - "    return edge_type > 0\n"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "QWrHJFcYXyUB"
 - },
 - "source": [
 - "Main model class!"
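Before moving on to the main class, a tiny numerical check of the `clip_norm` helper may be useful (a sketch; it assumes the cell above has been run). Rows whose norm exceeds `limit` are rescaled onto the sphere of radius `limit`, while shorter rows pass through unchanged:

```python
import torch

vec = torch.tensor([[3.0, 4.0], [0.3, 0.4]])  # row norms: 5.0 and 0.5
clipped = clip_norm(vec, limit=1.0)           # helper defined in the cell above
print(clipped)                # tensor([[0.6000, 0.8000], [0.3000, 0.4000]])
print(clipped.norm(dim=-1))   # tensor([1.0000, 0.5000])
```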
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "MCeZA1qQXzoK" - }, - "outputs": [], - "source": [ - "class MoleculeGNN(ModelMixin, ConfigMixin):\n", - " @register_to_config\n", - " def __init__(\n", - " self,\n", - " hidden_dim=128,\n", - " num_convs=6,\n", - " num_convs_local=4,\n", - " cutoff=10.0,\n", - " mlp_act=\"relu\",\n", - " edge_order=3,\n", - " edge_encoder=\"mlp\",\n", - " smooth_conv=True,\n", - " ):\n", - " super().__init__()\n", - " self.cutoff = cutoff\n", - " self.edge_encoder = edge_encoder\n", - " self.edge_order = edge_order\n", - "\n", - " \"\"\"\n", - " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", - " in SchNetEncoder\n", - " \"\"\"\n", - " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - "\n", - " \"\"\"\n", - " The graph neural network that extracts node-wise features.\n", - " \"\"\"\n", - " self.encoder_global = SchNetEncoder(\n", - " hidden_channels=hidden_dim,\n", - " num_filters=hidden_dim,\n", - " num_interactions=num_convs,\n", - " edge_channels=self.edge_encoder_global.out_channels,\n", - " cutoff=cutoff,\n", - " smooth=smooth_conv,\n", - " )\n", - " self.encoder_local = GINEncoder(\n", - " hidden_dim=hidden_dim,\n", - " num_convs=num_convs_local,\n", - " )\n", - "\n", - " \"\"\"\n", - " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", - " gradients w.r.t. edge_length (out_dim = 1).\n", - " \"\"\"\n", - " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " \"\"\"\n", - " Incorporate parameters together\n", - " \"\"\"\n", - " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", - " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", - "\n", - " def _forward(\n", - " self,\n", - " atom_type,\n", - " pos,\n", - " bond_index,\n", - " bond_type,\n", - " batch,\n", - " time_step, # NOTE, model trained without timestep performed best\n", - " edge_index=None,\n", - " edge_type=None,\n", - " edge_length=None,\n", - " return_edges=False,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - " ):\n", - " \"\"\"\n", - " Args:\n", - " atom_type: Types of atoms, (N, ).\n", - " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", - " bond_type: Bond types, (E, ).\n", - " batch: Node index to graph index, (N, ).\n", - " \"\"\"\n", - " N = atom_type.size(0)\n", - " if edge_index is None or edge_type is None or edge_length is None:\n", - " edge_index, edge_type = extend_graph_order_radius(\n", - " num_nodes=N,\n", - " pos=pos,\n", - " edge_index=bond_index,\n", - " edge_type=bond_type,\n", - " batch=batch,\n", - " order=self.edge_order,\n", - " cutoff=self.cutoff,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " is_sidechain=is_sidechain,\n", - " )\n", - " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", - " local_edge_mask = is_local_edge(edge_type) # (E, )\n", - "\n", - " # with the parameterization of 
NCSNv2\n", - " # DDPM loss implicit handle the noise variance scale conditioning\n", - " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", - "\n", - " # Encoding global\n", - " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - "\n", - " # Global\n", - " node_attr_global = self.encoder_global(\n", - " z=atom_type,\n", - " edge_index=edge_index,\n", - " edge_length=edge_length,\n", - " edge_attr=edge_attr_global,\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_global = assemble_atom_pair_feature(\n", - " node_attr=node_attr_global,\n", - " edge_index=edge_index,\n", - " edge_attr=edge_attr_global,\n", - " ) # (E_global, 2H)\n", - " # Invariant features of edges (radius graph, global)\n", - " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", - "\n", - " # Encoding local\n", - " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - " # edge_attr += temb_edge\n", - "\n", - " # Local\n", - " node_attr_local = self.encoder_local(\n", - " z=atom_type,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_local = assemble_atom_pair_feature(\n", - " node_attr=node_attr_local,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " ) # (E_local, 2H)\n", - "\n", - " # Invariant features of edges (bond graph, local)\n", - " if isinstance(sigma_edge, torch.Tensor):\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", - " 1.0 / sigma_edge[local_edge_mask]\n", - " ) # (E_local, 1)\n", - " else:\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", - "\n", - " if return_edges:\n", - " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", - " else:\n", - " return edge_inv_global, edge_inv_local\n", - "\n", - " def forward(\n", - " self,\n", - " sample,\n", - " timestep: Union[torch.Tensor, float, int],\n", - " return_dict: bool = True,\n", - " sigma=1.0,\n", - " global_start_sigma=0.5,\n", - " w_global=1.0,\n", - " extend_order=False,\n", - " extend_radius=True,\n", - " clip_local=None,\n", - " clip_global=1000.0,\n", - " ) -> Union[MoleculeGNNOutput, Tuple]:\n", - " r\"\"\"\n", - " Args:\n", - " sample: packed torch geometric object\n", - " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", - " return_dict (`bool`, *optional*, defaults to `True`):\n", - " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", - " Returns:\n", - " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", - " `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample tensor.\n",
 - "        \"\"\"\n",
 - "\n",
 - "        # unpack sample\n",
 - "        atom_type = sample.atom_type\n",
 - "        bond_index = sample.edge_index\n",
 - "        bond_type = sample.edge_type\n",
 - "        num_graphs = sample.num_graphs\n",
 - "        pos = sample.pos\n",
 - "\n",
 - "        timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n",
 - "\n",
 - "        edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n",
 - "            atom_type=atom_type,\n",
 - "            pos=sample.pos,\n",
 - "            bond_index=bond_index,\n",
 - "            bond_type=bond_type,\n",
 - "            batch=sample.batch,\n",
 - "            time_step=timesteps,\n",
 - "            return_edges=True,\n",
 - "            extend_order=extend_order,\n",
 - "            extend_radius=extend_radius,\n",
 - "        )  # (E_global, 1), (E_local, 1)\n",
 - "\n",
 - "        # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n",
 - "        node_eq_local = graph_field_network(\n",
 - "            edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n",
 - "        )\n",
 - "        if clip_local is not None:\n",
 - "            node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n",
 - "\n",
 - "        # Global\n",
 - "        if sigma < global_start_sigma:\n",
 - "            edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n",
 - "            node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n",
 - "            node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n",
 - "        else:\n",
 - "            node_eq_global = 0\n",
 - "\n",
 - "        # Sum\n",
 - "        eps_pos = node_eq_local + node_eq_global * w_global\n",
 - "\n",
 - "        if not return_dict:\n",
 - "            return (-eps_pos,)\n",
 - "\n",
 - "        return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "CCIrPYSJj9wd"
 - },
 - "source": [
 - "### Load pretrained model"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "YdrAr6Ch--Ab"
 - },
 - "source": [
 - "#### Load a model\n",
 - "The model used here is built around an equivariant convolutional layer, named graph field network (GFN).\n",
 - "\n",
 - "The warning about `betas` and `alphas` can be ignored; those parameters were moved to the scheduler."
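Loading itself follows the usual `ModelMixin` pattern — roughly like the sketch below, where `"some-org/geodiff-checkpoint"` is a hypothetical placeholder, not the actual checkpoint id used in the next cell:

```python
# Sketch only: MoleculeGNN inherits `from_pretrained` via ModelMixin.
# "some-org/geodiff-checkpoint" is a placeholder repo id, not the real one.
model = MoleculeGNN.from_pretrained("some-org/geodiff-checkpoint").to("cuda")
model.eval()  # redundant but explicit; from_pretrained already defaults to eval mode
```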
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 172, - "referenced_widgets": [ - "d90f304e9560472eacfbdd11e46765eb", - "1c6246f15b654f4daa11c9bcf997b78c", - "c2321b3bff6f490ca12040a20308f555", - "b7feb522161f4cf4b7cc7c1a078ff12d", - "e2d368556e494ae7ae4e2e992af2cd4f", - "bbef741e76ec41b7ab7187b487a383df", - "561f742d418d4721b0670cc8dd62e22c", - "872915dd1bb84f538c44e26badabafdd", - "d022575f1fa2446d891650897f187b4d", - "fdc393f3468c432aa0ada05e238a5436", - "2c9362906e4b40189f16d14aa9a348da", - "6010fc8daa7a44d5aec4b830ec2ebaa1", - "7e0bb1b8d65249d3974200686b193be2", - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "6526646be5ed415c84d1245b040e629b", - "24d31fc3576e43dd9f8301d2ef3a37ab", - "2918bfaadc8d4b1a9832522c40dfefb8", - "a4bfdca35cc54dae8812720f1b276a08", - "e4901541199b45c6a18824627692fc39", - "f915cf874246446595206221e900b2fe", - "a9e388f22a9742aaaf538e22575c9433", - "42f6c3db29d7484ba6b4f73590abd2f4" - ] - }, - "id": "DyCo0nsqjbml", - "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d90f304e9560472eacfbdd11e46765eb", - "version_major": 2, - "version_minor": 0 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1g_6zOabItDk" }, - "text/plain": [ - "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", - "\n", - "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", - "\n" - ] - } - ], - "source": [ - "import torch\n", - "\n", - "\n", - "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", - "dataset = torch.load('/content/molecules.pkl')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QZcmy1EvKQRk" - }, - "source": [ - "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "JVjz6iH_H6Eh", - "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" - }, - "outputs": [ { - "data": { - "text/plain": [ - "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" + "cell_type": "markdown", + "metadata": { + "id": "VfthW90vI0nw" + }, + "source": [ + "Install Conda for some more complex dependencies for geometric networks." 
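To make those tensor shapes concrete: `pos_ref` stacks several reference conformations on top of each other, so the number of 3d targets per molecule can be recovered by dividing (a quick check, assuming the cells above have been run):

```python
data = dataset[0]
num_atoms = data.atom_type.size(0)            # 51 atoms for this entry
num_refs = data.pos_ref.size(0) // num_atoms  # 255 // 51 = 5 reference conformations
print(num_atoms, num_refs)
```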
] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - }
 - ],
 - "source": [
 - "dataset[0]"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "vHNiZAUxNgoy"
 - },
 - "source": [
 - "## Run the diffusion process"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "jZ1KZrxKqENg"
 - },
 - "source": [
 - "#### Helper Functions"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": null,
 - "metadata": {
 - "id": "s240tYueqKKf"
 - },
 - "outputs": [],
 - "source": [
 - "import copy\n",
 - "import os\n",
 - "import pickle\n",
 - "\n",
 - "from torch_geometric.data import Batch, Data\n",
 - "from torch_scatter import scatter_mean\n",
 - "from tqdm import tqdm\n",
 - "\n",
 - "\n",
 - "def repeat_data(data: Data, num_repeat) -> Batch:\n",
 - "    datas = [copy.deepcopy(data) for i in range(num_repeat)]\n",
 - "    return Batch.from_data_list(datas)\n",
 - "\n",
 - "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n",
 - "    datas = batch.to_data_list()\n",
 - "    new_data = []\n",
 - "    for i in range(num_repeat):\n",
 - "        new_data += copy.deepcopy(datas)\n",
 - "    return Batch.from_data_list(new_data)"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "AMnQTk0eqT7Z"
 - },
 - "source": [
 - "#### Constants"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": null,
 - "metadata": {
 - "id": "WYGkzqgzrHmF"
 - },
 - "outputs": [],
 - "source": [
 - "num_samples = 1  # solutions per molecule\n",
 - "num_molecules = 3\n",
 - "\n",
 - "DEVICE = 'cuda'\n",
 - "sampling_type = 'ddpm_noisy'  # the paper also uses \"generalize\" and \"ld\"\n",
 - "# constants for inference\n",
 - "w_global = 0.5  # 0.3 for qm9\n",
 - "global_start_sigma = 0.5\n",
 - "eta = 1.0\n",
 - "clip_local = None\n",
 - "clip_pos = None\n",
 - "\n",
 - "# constants for data handling\n",
 - "save_traj = False\n",
 - "save_data = False\n",
 - "output_dir = '/content/'"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "-xD5bJ3SqM7t"
 - },
 - "source": [
 - "#### Generate samples!\n",
 - "Note that the 3d representation of a molecule is referred to as the **conformation**"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": null,
 - "metadata": {
 - "colab": {
 - "base_uri": "https://localhost:8080/"
 - },
 - "id": "x9xuLUNg26z1",
 - "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4"
 - },
 - "outputs": [
 - {
 - "name": "stderr",
 - "output_type": "stream",
 - "text": [
 - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
 - "  after removing the cwd from sys.path.\n",
 - "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n"
 - ]
 - }
 - ],
 - "source": [
 - "results = []\n",
 - "\n",
 - "# define sigmas\n",
 - "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n",
 - "sigmas = sigmas.to(DEVICE)\n",
 - "\n",
 - "for count, data in enumerate(tqdm(dataset)):\n",
 - "    num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n",
 - "\n",
 - "    data_input = data.clone()\n",
 - "    data_input['pos_ref'] = None\n",
 - "    batch = repeat_data(data_input, num_samples).to(DEVICE)\n",
 - "\n",
 - "    # initial configuration\n",
 - "    pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n",
 - "\n",
 - "    # for logging animation of denoising\n",
 - "    pos_traj = []\n",
 - "    with torch.no_grad():\n",
 - "\n",
 - "        # scale initial sample\n",
 - "        pos = pos_init * sigmas[-1]\n",
 - "        for t in scheduler.timesteps:\n",
 - "            batch.pos = pos\n",
 - "\n",
 - "            # generate geometry with model, then filter it\n",
 - "            epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n",
 - "\n",
 - "            # Update\n",
 - "            reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n",
 - "\n",
 - "            pos = reconstructed_pos\n",
 - "\n",
 - "            if torch.isnan(pos).any():\n",
 - "                print(\"NaN detected. Please restart.\")\n",
 - "                raise FloatingPointError()\n",
 - "\n",
 - "            # recenter graph of positions for next iteration\n",
 - "            pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n",
 - "\n",
 - "            # optional clipping\n",
 - "            if clip_pos is not None:\n",
 - "                pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n",
 - "            pos_traj.append(pos.clone().cpu())\n",
 - "\n",
 - "    pos_gen = pos.cpu()\n",
 - "    if save_traj:\n",
 - "        data.pos_gen = torch.stack(pos_traj)  # full denoising trajectory (pos_traj is a list of CPU tensors)\n",
 - "    else:\n",
 - "        data.pos_gen = pos_gen\n",
 - "    results.append(data)\n",
 - "\n",
 - "\n",
 - "if save_data:\n",
 - "    save_path = os.path.join(output_dir, 'samples_all.pkl')\n",
 - "\n",
 - "    with open(save_path, 'wb') as f:\n",
 - "        pickle.dump(results, f)"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "fSApwSaZNndW"
 - },
 - "source": [
 - "## Render the results!"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "d47Zxo2OKdgZ"
 - },
 - "source": [
 - "This function allows us to render 3d structures in Colab."
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": null,
 - "metadata": {
 - "id": "e9Cd0kCAv9b8"
 - },
 - "outputs": [],
 - "source": [
 - "from google.colab import output\n",
 - "\n",
 - "\n",
 - "output.enable_custom_widget_manager()"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "RjaVuR15NqzF"
 - },
 - "source": [
 - "### Helper functions"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "28rBYa9NKhlz"
 - },
 - "source": [
 - "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer."
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": null,
 - "metadata": {
 - "id": "LKdKdwxcyTQ6"
 - },
 - "outputs": [],
 - "source": [
 - "from copy import deepcopy\n",
 - "\n",
 - "\n",
 - "def set_rdmol_positions(rdkit_mol, pos):\n",
 - "    \"\"\"\n",
 - "    Args:\n",
 - "        rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
 - "        pos: (N_atoms, 3)\n",
 - "    \"\"\"\n",
 - "    mol = deepcopy(rdkit_mol)\n",
 - "    set_rdmol_positions_(mol, pos)\n",
 - "    return mol\n",
 - "\n",
 - "def set_rdmol_positions_(mol, pos):\n",
 - "    \"\"\"\n",
 - "    Args:\n",
 - "        mol: An `rdkit.Chem.rdchem.Mol` object (modified in place).\n",
 - "        pos: (N_atoms, 3)\n",
 - "    \"\"\"\n",
 - "    for i in range(pos.shape[0]):\n",
 - "        mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n",
 - "    return mol\n"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {
 - "id": "NuE10hcpKmzK"
 - },
 - "source": [
 - "Process the generated data to make it easy to view."
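As a small standalone illustration of the helpers above (`set_rdmol_positions` returns a modified deep copy, while `set_rdmol_positions_` mutates its argument in place; the molecule and coordinates here are arbitrary stand-ins, not generated data):

```python
import torch
from rdkit import Chem
from rdkit.Chem import AllChem

demo = Chem.AddHs(Chem.MolFromSmiles("CCO"))     # arbitrary example molecule
AllChem.EmbedMolecule(demo)                      # give it an initial conformer
new_pos = torch.randn(demo.GetNumAtoms(), 3)     # random stand-in coordinates
demo_copy = set_rdmol_positions(demo, new_pos)   # `demo` keeps its original conformer
```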
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "KieVE1vc0_Vs", - "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "collect 5 generated molecules in `mols`\n" - ] - } - ], - "source": [ - "# the model can generate multiple conformations per 2d geometry\n", - "num_gen = results[0]['pos_gen'].shape[0]\n", - "\n", - "# init storage objects\n", - "mols_gen = []\n", - "mols_orig = []\n", - "for to_process in results:\n", - "\n", - " # store the reference 3d position\n", - " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # store the generated 3d position\n", - " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # copy data to new object\n", - " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", - "\n", - " # append results\n", - " mols_gen.append(new_mol)\n", - " mols_orig.append(to_process.rdmol)\n", - "\n", - "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tin89JwMKp4v" - }, - "source": [ - "Import tools to visualize the 2d chemical diagram of the molecule." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yqV6gllSZn38" - }, - "outputs": [], - "source": [ - "from IPython.display import SVG, display\n", - "from rdkit import Chem\n", - "from rdkit.Chem.Draw import rdMolDraw2D as MD2" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TFNKmGddVoOk" - }, - "source": [ - "Select molecule to visualize" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KzuwLlrrVaGc" - }, - "outputs": [], - "source": [ - "idx = 0\n", - "assert idx < len(results), \"selected molecule that was not generated\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hkb8w0_SNtU8" - }, - "source": [ - "### Viewing" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I3R4QBQeKttN" - }, - "source": [ - "This 2D rendering is the equivalent of the **input to the model**!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 321 - }, - "id": "gkQRWjraaKex", - "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" - }, - "outputs": [ - { - "data": { - "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", - "text/plain": [ - "" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2WNFzSnbiE0k", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q condacolab" ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", - "molSize=(450,300)\n", - "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", - "drawer.DrawMolecule(mc)\n", - "drawer.FinishDrawing()\n", - "svg = drawer.GetDrawingText()\n", - "display(SVG(svg.replace('svg:','')))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z4FDMYMxKw2I" - }, - "source": [ - "Generate the 3d molecule!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 17, - "referenced_widgets": [ - "695ab5bbf30a4ab19df1f9f33469f314", - "eac6a8dcdc9d4335a2e51031793ead29" - ] - }, - "id": "aT1Bkb8YxJfV", - "outputId": "b98870ae-049d-4386-b676-166e9526bda2" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "695ab5bbf30a4ab19df1f9f33469f314", - "version_major": 2, - "version_minor": 0 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NUsbWYCUI7Km" }, - "text/plain": [] - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + "source": [ + "Setup Conda" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FZelreINdmd0", + "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✨🍰✨ Everything looks OK!\n" + ] } - } - } - }, - "output_type": "display_data" - } - ], - "source": [ - "from nglview import show_rdkit as show" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 337, - "referenced_widgets": [ - "be446195da2b4ff2aec21ec5ff963a54", - "c6596896148b4a8a9c57963b67c7782f", - "2489b5e5648541fbbdceadb05632a050", - "01e0ba4e5da04914b4652b8d58565d7b", - "c30e6c2f3e2a44dbbb3d63bd519acaa4", - "f31c6e40e9b2466a9064a2669933ecd5", - "19308ccac642498ab8b58462e3f1b0bb", - "4a081cdc2ec3421ca79dd933b7e2b0c4", - "e5c0d75eb5e1447abd560c8f2c6017e1", - "5146907ef6764654ad7d598baebc8b58", - "144ec959b7604a2cabb5ca46ae5e5379", - "abce2a80e6304df3899109c6d6cac199", - "65195cb7a4134f4887e9dd19f3676462" - ] - }, - "id": "pxtq8I-I18C-", - "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "be446195da2b4ff2aec21ec5ff963a54", - "version_major": 2, - "version_minor": 0 + ], + "source": [ + "import condacolab\n", + "condacolab.install()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JzDHaPU7I9Sn" }, - "text/plain": [ - "NGLWidget()" + "source": [ + "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" ] - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": 
"https://localhost:8080/" + }, + "id": "JMxRjHhL7w8V", + "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", + "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - cudatoolkit=11.1\n", + " - pytorch\n", + " - torchaudio\n", + " - torchvision\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 960 KB\n", + "\n", + "The following packages will be UPDATED:\n", + "\n", + " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", + "Preparing transaction: / \b\bdone\n", + "Verifying transaction: \\ \b\bdone\n", + "Executing transaction: / \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] } - } + ], + "source": [ + "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", + "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Need to remove a pathspec for colab that specifies the incorrect cuda version." 
+ ], + "metadata": { + "id": "QDS6FPZ0Tu5b" } - }, - "output_type": "display_data" - } - ], - "source": [ - "# new molecule\n", - "show(mols_gen[idx])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KJr4h2mwXeTo" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "01e0ba4e5da04914b4652b8d58565d7b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", - "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" + }, + { + "cell_type": "code", + "source": [ + "!rm /usr/local/conda-meta/pinned" ], - "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" - } - }, - "144ec959b7604a2cabb5ca46ae5e5379": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "19308ccac642498ab8b58462e3f1b0bb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": 
py37hf2a6cf1_0 21.8 MB conda-forge\n", + " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", + " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 55.9 MB\n", + "\n", + "The following NEW packages will be INSTALLED:\n", + "\n", + " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", + " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", + " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", + " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", + " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", + " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", + " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", + " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", + " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", + " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", + " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", + " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", + " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", + " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", + " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", + " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", + " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", + " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", + " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", + "\n", + "The following packages will be DOWNGRADED:\n", + "\n", + " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", + "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", + "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", + "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", + "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", + "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", + "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", + "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", + "pytorch-cluster-1.5. 
| 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", + "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", + "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", + "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", + "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", + "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", + "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", + "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", + "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", + "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", + "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", + "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", + "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", + "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } ], - "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" - } - }, - "65195cb7a4134f4887e9dd19f3676462": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ButtonStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "button_color": null, - "font_weight": "" - } - }, - "6526646be5ed415c84d1245b040e629b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", - "placeholder": "​", - "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", - "value": " 401/401 [00:00<00:00, 13.5kB/s]" - } - }, - "695ab5bbf30a4ab19df1f9f33469f314": { - "model_module": "nglview-js-widgets", - "model_module_version": "3.0.1", - "model_name": "ColormakerRegistryModel", - "state": { - "_dom_classes": [], - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "ColormakerRegistryModel", - "_msg_ar": [], - "_msg_q": [], - "_ready": false, - "_view_count": null, - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "ColormakerRegistryView", - "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" - } - }, - "7e0bb1b8d65249d3974200686b193be2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": 
"IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", - "placeholder": "​", - "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", - "value": "Downloading: 100%" - } - }, - "872915dd1bb84f538c44e26badabafdd": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a4bfdca35cc54dae8812720f1b276a08": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "a9e388f22a9742aaaf538e22575c9433": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "abce2a80e6304df3899109c6d6cac199": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - 
"align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "34px" - } - }, - "b7feb522161f4cf4b7cc7c1a078ff12d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", - "placeholder": "​", - "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", - "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" - } - }, - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", - "max": 401, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", - "value": 401 - } - }, - "bbef741e76ec41b7ab7187b487a383df": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "be446195da2b4ff2aec21ec5ff963a54": { - "model_module": "nglview-js-widgets", - 
"model_module_version": "3.0.1", - "model_name": "NGLModel", - "state": { - "_camera_orientation": [ - -15.519693580202304, - -14.065056548036177, - -23.53197484807691, - 0, - -23.357853515109753, - 20.94055073042662, - 2.888695042134944, - 0, - 14.352363398292775, - 18.870825741878015, - -20.744689572909344, - 0, - 0.2724999189376831, - 0.6940000057220459, - -0.3734999895095825, - 1 + "source": [ + "!conda install -c rusty1s pytorch-geometric=1.7.2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppxv6Mdkalbc" + }, + "source": [ + "### Install Diffusers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mgQA_XN-XGY2", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content\n", + "Cloning into 'diffusers'...\n", + "remote: Enumerating objects: 9298, done.\u001b[K\n", + "remote: Counting objects: 100% (40/40), done.\u001b[K\n", + "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", + "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", + "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", + "Resolving deltas: 100% (6168/6168), done.\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } ], - "_camera_str": "orthographic", - "_dom_classes": [], - "_gui_theme": null, - "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", - "_igui": null, - "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "NGLModel", - "_ngl_color_dict": {}, - "_ngl_coordinate_resource": {}, - "_ngl_full_stage_parameters": { - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "backgroundColor": "white", - "cameraEyeSep": 0.3, - "cameraFov": 40, - "cameraType": "perspective", - "clipDist": 10, - "clipFar": 100, - "clipNear": 0, - "fogFar": 100, - "fogNear": 50, - "hoverTimeout": 0, - "impostor": true, - "lightColor": 14540253, - "lightIntensity": 1, - "mousePreset": "default", - "panSpeed": 1, - "quality": "medium", - "rotateSpeed": 2, - "sampleLevel": 0, - "tooltip": true, - "workerDefault": true, - "zoomSpeed": 1.2 + "source": [ + "%cd /content\n", + "\n", + "# install latest HF diffusers (will update to the release once added)\n", + "!git clone https://github.com/huggingface/diffusers.git\n", + "!pip install -q /content/diffusers\n", + "\n", + "# dependencies for diffusers\n", + "!pip install -q datasets transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LZO6AJKuJKO8" }, - "_ngl_msg_archive": [ - { - "args": [ - { - "binary": false, - "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 
-2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n",
- "type": "blob"
- }
- ],
- "kwargs": {
- "defaultRepresentation": true,
- "ext": "pdb"
+ "source": [
+ "Check that torch is installed correctly and is utilizing the GPU in Colab"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gZt7BNi1e1PA",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 53
+ },
+ "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "True\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "'1.8.2'"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ }
+ },
+ "metadata": {},
+ "execution_count": 8
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "print(torch.cuda.is_available())\n",
+ "torch.__version__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KLE7CqlfJNUO"
+ },
+ "source": [
+ "### Install Chemistry-specific Dependencies\n",
+ "\n",
+ "Install RDKit, a tool for working with and visualizing chemistry in Python (you will use it to visualize the generated molecules later)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0CPv_NvehRz3",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting rdkit\n",
+ "  Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n",
+ "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n",
+ "Installing collected packages: rdkit\n",
+ "Successfully installed rdkit-2022.3.5\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install rdkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "88GaDbDPxJ5I" + }, + "source": [ + "### Get viewer from nglview\n", + "\n", + "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", + "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", + "The rdmol in this object is a source of ground truth for the generated molecules.\n", + "\n", + "You will use one rendering function from nglviewer later!\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jcl8GCS2mz6t", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting nglview\n", + " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", + "Collecting jupyterlab-widgets\n", + " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipywidgets>=7\n", + " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting widgetsnbextension~=4.0\n", + " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipython>=6.1.0\n", + " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipykernel>=4.5.1\n", + " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting traitlets>=4.3.1\n", + " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: packaging in 
/usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", + "Collecting pyzmq>=17\n", + " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting matplotlib-inline>=0.1\n", + " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", + "Collecting tornado>=6.1\n", + " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nest-asyncio\n", + " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", + "Collecting debugpy>=1.0\n", + " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting psutil\n", + " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jupyter-client>=6.1.12\n", + " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pickleshare\n", + " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", + "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", + "Collecting backcall\n", + " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", + "Collecting pexpect>4.3\n", + " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pygments\n", + " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jedi>=0.16\n", + " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", + " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", + 
"Collecting parso<0.9.0,>=0.8.0\n", + " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", + "Collecting entrypoints\n", + " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", + "Collecting jupyter-core>=4.9.2\n", + " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ptyprocess>=0.5\n", + " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", + "Collecting wcwidth\n", + " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", + "Building wheels for collected packages: nglview\n", + " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", + " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", + "Successfully built nglview\n", + "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", + "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] }, - "methodName": "loadFile", - "reconstruc_color_scheme": false, - "target": "Stage", - "type": "call_method" - } + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "pexpect", + "pickleshare", + "wcwidth" + ] + } + } + }, + "metadata": {} + } + ], + "source": [ + "!pip install nglview" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Create a diffusion model" + ], + "metadata": { + "id": "8t8_e_uVLdKB" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Model class(es)" + ], + "metadata": { + "id": "G0rMncVtNSqU" + } + }, + { + "cell_type": "markdown", + "source": [ + "Imports" + ], + "metadata": { + "id": "L5FEXz5oXkzt" + } + }, + { + "cell_type": "code", + "source": [ + "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", + "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", + "from dataclasses import dataclass\n", + "from typing import Callable, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor, nn\n", + "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", + "\n", + "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", + "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", + "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", + "from torch_scatter import scatter_add\n", + "from torch_sparse import SparseTensor, coalesce\n", + "\n", + "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", + "from diffusers.modeling_utils import ModelMixin\n", + "from diffusers.utils import BaseOutput\n" + ], + "metadata": { + "id": "-3-P4w5sXkRU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Helper classes" + ], + "metadata": { + "id": "EzJQXPN_XrMX" + } + }, + { + "cell_type": "code", + "source": [ + "@dataclass\n", + "class MoleculeGNNOutput(BaseOutput):\n", + " \"\"\"\n", + " Args:\n", + " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", + " Hidden states output. Output of last layer of model.\n", + " \"\"\"\n", + "\n", + " sample: torch.Tensor\n", + "\n", + "\n", + "class MultiLayerPerceptron(nn.Module):\n", + " \"\"\"\n", + " Multi-layer Perceptron. 
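Maps an input of shape (*, input_dim) to an output of shape (*, hidden_dims[-1]) through a stack of Linear layers. 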
Note there is no activation or dropout in the last layer.\n", + " Args:\n", + " input_dim (int): input dimension\n", + " hidden_dim (list of int): hidden dimensions\n", + " activation (str or function, optional): activation function\n", + " dropout (float, optional): dropout rate\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", + " super(MultiLayerPerceptron, self).__init__()\n", + "\n", + " self.dims = [input_dim] + hidden_dims\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", + " self.activation = None\n", + " if dropout > 0:\n", + " self.dropout = nn.Dropout(dropout)\n", + " else:\n", + " self.dropout = None\n", + "\n", + " self.layers = nn.ModuleList()\n", + " for i in range(len(self.dims) - 1):\n", + " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\"\"\"\n", + " for i, layer in enumerate(self.layers):\n", + " x = layer(x)\n", + " if i < len(self.layers) - 1:\n", + " if self.activation:\n", + " x = self.activation(x)\n", + " if self.dropout:\n", + " x = self.dropout(x)\n", + " return x\n", + "\n", + "\n", + "class ShiftedSoftplus(torch.nn.Module):\n", + " def __init__(self):\n", + " super(ShiftedSoftplus, self).__init__()\n", + " self.shift = torch.log(torch.tensor(2.0)).item()\n", + "\n", + " def forward(self, x):\n", + " return F.softplus(x) - self.shift\n", + "\n", + "\n", + "class CFConv(MessagePassing):\n", + " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", + " super(CFConv, self).__init__(aggr=\"add\")\n", + " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", + " self.lin2 = Linear(num_filters, out_channels)\n", + " self.nn = mlp\n", + " self.cutoff = cutoff\n", + " self.smooth = smooth\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", + " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", + " self.lin2.bias.data.fill_(0)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " if self.smooth:\n", + " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", + " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", + " else:\n", + " C = (edge_length <= self.cutoff).float()\n", + " W = self.nn(edge_attr) * C.view(-1, 1)\n", + "\n", + " x = self.lin1(x)\n", + " x = self.propagate(edge_index, x=x, W=W)\n", + " x = self.lin2(x)\n", + " return x\n", + "\n", + " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", + " return x_j * W\n", + "\n", + "\n", + "class InteractionBlock(torch.nn.Module):\n", + " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", + " super(InteractionBlock, self).__init__()\n", + " mlp = Sequential(\n", + " Linear(num_gaussians, num_filters),\n", + " ShiftedSoftplus(),\n", + " Linear(num_filters, num_filters),\n", + " )\n", + " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", + " self.act = ShiftedSoftplus()\n", + " self.lin = Linear(hidden_channels, hidden_channels)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " x = self.conv(x, edge_index, edge_length, edge_attr)\n", + " x = self.act(x)\n", + " x = self.lin(x)\n", + " return x\n", + "\n", + "\n", + "class 
SchNetEncoder(Module):\n",
+ "    def __init__(\n",
+ "        self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n",
+ "    ):\n",
+ "        super().__init__()\n",
+ "\n",
+ "        self.hidden_channels = hidden_channels\n",
+ "        self.num_filters = num_filters\n",
+ "        self.num_interactions = num_interactions\n",
+ "        self.cutoff = cutoff\n",
+ "\n",
+ "        self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n",
+ "\n",
+ "        self.interactions = ModuleList()\n",
+ "        for _ in range(num_interactions):\n",
+ "            block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n",
+ "            self.interactions.append(block)\n",
+ "\n",
+ "    def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n",
+ "        if embed_node:\n",
+ "            assert z.dim() == 1 and z.dtype == torch.long\n",
+ "            h = self.embedding(z)\n",
+ "        else:\n",
+ "            h = z\n",
+ "        for interaction in self.interactions:\n",
+ "            h = h + interaction(h, edge_index, edge_length, edge_attr)\n",
+ "\n",
+ "        return h\n",
+ "\n",
+ "\n",
+ "class GINEConv(MessagePassing):\n",
+ "    \"\"\"\n",
+ "    Custom version of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\" paper\n",
+ "    (https://arxiv.org/abs/1810.00826). Note that this implementation has the added option of a custom activation.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n",
+ "        super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n",
+ "        self.nn = mlp\n",
+ "        self.initial_eps = eps\n",
+ "\n",
+ "        if isinstance(activation, str):\n",
+ "            self.activation = getattr(F, activation)\n",
+ "        else:\n",
+ "            self.activation = None\n",
+ "\n",
+ "        if train_eps:\n",
+ "            self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n",
+ "        else:\n",
+ "            self.register_buffer(\"eps\", torch.Tensor([eps]))\n",
+ "\n",
+ "    def forward(\n",
+ "        self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n",
+ "    ) -> torch.Tensor:\n",
+ "        \"\"\"\"\"\"\n",
+ "        if isinstance(x, torch.Tensor):\n",
+ "            x: OptPairTensor = (x, x)\n",
+ "\n",
+ "        # Node and edge feature dimensionalities need to match.\n",
+ "        if isinstance(edge_index, torch.Tensor):\n",
+ "            assert edge_attr is not None\n",
+ "            assert x[0].size(-1) == edge_attr.size(-1)\n",
+ "        elif isinstance(edge_index, SparseTensor):\n",
+ "            assert x[0].size(-1) == edge_index.size(-1)\n",
+ "\n",
+ "        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n",
+ "        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n",
+ "\n",
+ "        x_r = x[1]\n",
+ "        if x_r is not None:\n",
+ "            out += (1 + self.eps) * x_r\n",
+ "\n",
+ "        return self.nn(out)\n",
+ "\n",
+ "    def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n",
+ "        if self.activation:\n",
+ "            return self.activation(x_j + edge_attr)\n",
+ "        else:\n",
+ "            return x_j + edge_attr\n",
+ "\n",
+ "    def __repr__(self):\n",
+ "        return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n",
+ "\n",
+ "\n",
+ "class GINEncoder(torch.nn.Module):\n",
+ "    def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n",
+ "        super().__init__()\n",
+ "\n",
+ "        self.hidden_dim = hidden_dim\n",
+ "        self.num_convs = num_convs\n",
+ "        self.short_cut = short_cut\n",
+ "        self.concat_hidden = concat_hidden\n",
+ "        self.node_emb = nn.Embedding(100, hidden_dim)\n",
+ "\n",
+ "        if isinstance(activation, str):\n",
+ "            self.activation = getattr(F, activation)\n",
+ "        else:\n",
+ "            self.activation = None\n",
+ "\n",
+ "        self.convs = nn.ModuleList()\n",
+ "        for i in range(self.num_convs):\n",
+ "            self.convs.append(\n",
+ "                GINEConv(\n",
+ "                    MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n",
+ "                    activation=activation,\n",
+ "                )\n",
+ "            )\n",
+ "\n",
+ "    def forward(self, z, edge_index, edge_attr):\n",
+ "        \"\"\"\n",
+ "        Input:\n",
+ "            z: atom types of the batched graph, shape (num_node,)\n",
+ "            edge_index: edge indices, shape (2, num_edge)\n",
+ "            edge_attr: edge feature tensor, shape (num_edge, hidden)\n",
+ "        Output:\n",
+ "            node_feature: node-wise features, shape (num_node, hidden)\n",
+ "        \"\"\"\n",
+ "\n",
+ "        node_attr = self.node_emb(z)  # (num_node, hidden)\n",
+ "\n",
+ "        hiddens = []\n",
+ "        conv_input = node_attr  # (num_node, hidden)\n",
+ "\n",
+ "        for conv_idx, conv in enumerate(self.convs):\n",
+ "            hidden = conv(conv_input, edge_index, edge_attr)\n",
+ "            if conv_idx < len(self.convs) - 1 and self.activation is not None:\n",
+ "                hidden = self.activation(hidden)\n",
+ "            assert hidden.shape == conv_input.shape\n",
+ "            if self.short_cut and hidden.shape == conv_input.shape:\n",
+ "                hidden += conv_input\n",
+ "\n",
+ "            hiddens.append(hidden)\n",
+ "            conv_input = hidden\n",
+ "\n",
+ "        if self.concat_hidden:\n",
+ "            node_feature = torch.cat(hiddens, dim=-1)\n",
+ "        else:\n",
+ "            node_feature = hiddens[-1]\n",
+ "\n",
+ "        return node_feature\n",
+ "\n",
+ "\n",
+ "class MLPEdgeEncoder(Module):\n",
+ "    def __init__(self, hidden_dim=100, activation=\"relu\"):\n",
+ "        super().__init__()\n",
+ "        self.hidden_dim = hidden_dim\n",
+ "        self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n",
+ "        self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n",
+ "\n",
+ "    @property\n",
+ "    def out_channels(self):\n",
+ "        return self.hidden_dim\n",
+ "\n",
+ "    def forward(self, edge_length, edge_type):\n",
+ "        \"\"\"\n",
+ "        Input:\n",
+ "            edge_length: the length of edges, shape (E, 1)\n",
+ "            edge_type: the type of edges, shape (E,)\n",
+ "        Returns:\n",
+ "            edge_attr: the representation of edges, shape (E, hidden_dim)\n",
+ "        \"\"\"\n",
+ "        d_emb = self.mlp(edge_length)  # (num_edge, hidden_dim)\n",
+ "        edge_attr = self.bond_emb(edge_type)  # (num_edge, hidden_dim)\n",
+ "        return d_emb * edge_attr  # (num_edge, hidden)\n",
+ "\n",
+ "\n",
+ "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n",
+ "    h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n",
+ "    h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1)  # (E, 2H)\n",
+ "    return h_pair\n",
+ "\n",
+ "\n",
+ "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n",
+ "    \"\"\"\n",
+ "    Args:\n",
+ "        num_nodes: Number of atoms.\n",
+ "        edge_index: Bond indices of the original graph.\n",
+ "        edge_type: Bond types of the original graph.\n",
+ "        order: Extension order.\n",
+ "    Returns:\n",
+ "        new_edge_index: Extended edge indices.\n",
+ "        new_edge_type: Extended edge types.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def binarize(x):\n",
+ "        return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n",
+ "\n",
+ "    def get_higher_order_adj_matrix(adj, order):\n",
+ "        \"\"\"\n",
+ "        Args:\n",
+ "            adj: Binary adjacency matrix, (N, N).\n",
+ "            order: Maximum graph distance to consider.\n",
+ "        Returns:\n",
+ "            order_mat: (N, N) matrix whose entry (i, j) is the graph distance between i and j if it is\n",
+ "            at most `order`, and 0 otherwise.\n",
+ "        \"\"\"\n",
+ "        adj_mats = [\n",
+ "            torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n",
+ "            binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n",
+ "        ]\n",
+ "\n",
+ "        for i in range(2, order + 1):\n",
+ "            adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n",
+ "        order_mat = torch.zeros_like(adj)\n",
+ "\n",
+ "        for i in range(1, order + 1):\n",
+ "            order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n",
+ "\n",
+ "        return order_mat\n",
+ "\n",
+ "    num_types = 22\n",
+ "    # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n",
+ "    # from rdkit.Chem.rdchem import BondType as BT\n",
+ "    N = num_nodes\n",
+ "    adj = to_dense_adj(edge_index).squeeze(0)\n",
+ "    adj_order = get_higher_order_adj_matrix(adj, order)  # (N, N)\n",
+ "\n",
+ "    type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0)  # (N, N)\n",
+ "    type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n",
+ "    assert (type_mat * type_highorder == 0).all()\n",
+ "    type_new = type_mat + type_highorder\n",
+ "\n",
+ "    new_edge_index, new_edge_type = dense_to_sparse(type_new)\n",
+ "    _, edge_order = dense_to_sparse(adj_order)\n",
+ "\n",
+ "    # data.bond_edge_index = data.edge_index  # Save original edges\n",
+ "    new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N)  # modify data\n",
+ "\n",
+ "    return new_edge_index, new_edge_type\n",
+ "\n",
+ "\n",
+ "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n",
+ "    assert edge_type.dim() == 1\n",
+ "    N = pos.size(0)\n",
+ "\n",
+ "    bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n",
+ "\n",
+ "    if is_sidechain is None:\n",
+ "        rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch)  # (2, E_r)\n",
+ "    else:\n",
+ "        # fetch sidechain and its batch index\n",
+ "        is_sidechain = is_sidechain.bool()\n",
+ "        dummy_index = torch.arange(pos.size(0), device=pos.device)\n",
+ "        sidechain_pos = pos[is_sidechain]\n",
+ "        sidechain_index = dummy_index[is_sidechain]\n",
+ "        sidechain_batch = batch[is_sidechain]\n",
+ "\n",
+ "        assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n",
+ "        r_edge_index_x = assign_index[1]\n",
+ "        r_edge_index_y = assign_index[0]\n",
+ "        r_edge_index_y = sidechain_index[r_edge_index_y]\n",
+ "\n",
+ "        rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y))  # (2, E)\n",
+ "        rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x))  # (2, E)\n",
+ "        rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1)  # (2, 2E)\n",
+ "        # remove self-loops\n",
+ "        rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n",
+ "\n",
+ "    rgraph_adj = torch.sparse.LongTensor(\n",
+ "        rgraph_edge_index,\n",
+ "        torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n",
+ "        torch.Size([N, N]),\n",
+ "    )\n",
+ "\n",
+ "    composed_adj = (bgraph_adj + rgraph_adj).coalesce()  # Sparse (N, N, T)\n",
+ "\n",
+ "    new_edge_index = composed_adj.indices()\n",
+ "    new_edge_type = composed_adj.values().long()\n",
+ "\n",
+ "    return new_edge_index, new_edge_type\n",
+ "\n",
+ "\n",
+ "def extend_graph_order_radius(\n",
+ "    num_nodes,\n",
+ "    pos,\n",
+ "    edge_index,\n",
+ "    edge_type,\n",
+ "    batch,\n",
+ "    order=3,\n",
+ "    cutoff=10.0,\n",
+ "    extend_order=True,\n",
+ "    extend_radius=True,\n",
+ "    is_sidechain=None,\n",
+ "):\n",
+ "    if extend_order:\n",
+ "        edge_index, edge_type = _extend_graph_order(\n",
+ "            num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n",
+ "        )\n",
+ "\n",
+ "    if extend_radius:\n",
+ "        edge_index, edge_type = _extend_to_radius_graph(\n",
+ "            pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n",
+ "        )\n",
+ "\n",
+ "    return edge_index, edge_type\n",
+ "\n",
+ "\n",
+ "def get_distance(pos, edge_index):\n",
+ "    return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n",
+ "\n",
+ "\n",
+ "def graph_field_network(score_d, pos, edge_index, edge_length):\n",
+ "    \"\"\"\n",
+ "    Transformation that makes the epsilon predicted by the diffusion model roto-translationally equivariant. See\n",
+ "    equations 5-7 of the GeoDiff paper, https://arxiv.org/pdf/2203.02923.pdf\n",
+ "    \"\"\"\n",
+ "    N = pos.size(0)\n",
+ "    dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]])  # (E, 3)\n",
+ "    score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n",
+ "        -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n",
+ "    )  # (N, 3)\n",
+ "    return score_pos\n",
+ "\n",
+ "\n",
+ "def clip_norm(vec, limit, p=2):\n",
+ "    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)  # use the `p` argument instead of a hardcoded 2\n",
+ "    denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n",
+ "    return vec * denom\n",
+ "\n",
+ "\n",
+ "def is_local_edge(edge_type):\n",
+ "    return edge_type > 0\n"
+ ],
+ "metadata": {
+ "id": "oR1Y56QiLY90"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Main model class!",
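+ "\n",
+ "\n",
+ "As a quick aside, the sketch below (an addition, not part of the original GeoDiff code) checks that the `graph_field_network` helper defined above is rotation-equivariant, which is exactly what equations 5-7 of the GeoDiff paper require:\n",
+ "\n",
+ "```python\n",
+ "# Sanity check on a toy 3-atom chain: rotating the positions rotates the output field.\n",
+ "pos = torch.randn(3, 3)\n",
+ "edge_index = torch.tensor([[0, 1], [1, 2]])  # edges 0->1 and 1->2\n",
+ "edge_length = (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1, keepdim=True)\n",
+ "score_d = torch.randn(2, 1)  # one invariant scalar per edge\n",
+ "Q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix\n",
+ "out = graph_field_network(score_d, pos, edge_index, edge_length)\n",
+ "out_rot = graph_field_network(score_d, pos @ Q.T, edge_index, edge_length)  # edge lengths are invariant\n",
+ "assert torch.allclose(out @ Q.T, out_rot, atol=1e-5)  # equivariance holds\n",
+ "```"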
+ ],
+ "metadata": {
+ "id": "QWrHJFcYXyUB"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class MoleculeGNN(ModelMixin, ConfigMixin):\n",
+ "    @register_to_config\n",
+ "    def __init__(\n",
+ "        self,\n",
+ "        hidden_dim=128,\n",
+ "        num_convs=6,\n",
+ "        num_convs_local=4,\n",
+ "        cutoff=10.0,\n",
+ "        mlp_act=\"relu\",\n",
+ "        edge_order=3,\n",
+ "        edge_encoder=\"mlp\",\n",
+ "        smooth_conv=True,\n",
+ "    ):\n",
+ "        super().__init__()\n",
+ "        self.cutoff = cutoff\n",
+ "        self.edge_encoder = edge_encoder\n",
+ "        self.edge_order = edge_order\n",
+ "\n",
+ "        \"\"\"\n",
+ "        edge_encoder: Takes both edge type and edge length as input and outputs a vector. [Note]: node embedding is\n",
+ "        done in SchNetEncoder\n",
+ "        \"\"\"\n",
+ "        self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act)  # get_edge_encoder(config)\n",
+ "        self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act)  # get_edge_encoder(config)\n",
+ "\n",
+ "        \"\"\"\n",
+ "        The graph neural network that extracts node-wise features.\n",
+ "        \"\"\"\n",
+ "        self.encoder_global = SchNetEncoder(\n",
+ "            hidden_channels=hidden_dim,\n",
+ "            num_filters=hidden_dim,\n",
+ "            num_interactions=num_convs,\n",
+ "            edge_channels=self.edge_encoder_global.out_channels,\n",
+ "            cutoff=cutoff,\n",
+ "            smooth=smooth_conv,\n",
+ "        )\n",
+ "        self.encoder_local = GINEncoder(\n",
+ "            hidden_dim=hidden_dim,\n",
+ "            num_convs=num_convs_local,\n",
+ "        )\n",
+ "\n",
+ "        \"\"\"\n",
+ "        `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n",
+ "        gradients w.r.t. edge_length (out_dim = 1).\n",
+ "        \"\"\"\n",
+ "        self.grad_global_dist_mlp = MultiLayerPerceptron(\n",
+ "            2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
+ "        )\n",
+ "\n",
+ "        self.grad_local_dist_mlp = MultiLayerPerceptron(\n",
+ "            2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
+ "        )\n",
+ "\n",
+ "        \"\"\"\n",
+ "        Incorporate parameters together\n",
+ "        \"\"\"\n",
+ "        self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n",
+ "        self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n",
+ "\n",
+ "    def _forward(\n",
+ "        self,\n",
+ "        atom_type,\n",
+ "        pos,\n",
+ "        bond_index,\n",
+ "        bond_type,\n",
+ "        batch,\n",
+ "        time_step,  # NOTE: the model trained without timestep conditioning performed best\n",
+ "        edge_index=None,\n",
+ "        edge_type=None,\n",
+ "        edge_length=None,\n",
+ "        return_edges=False,\n",
+ "        extend_order=True,\n",
+ "        extend_radius=True,\n",
+ "        is_sidechain=None,\n",
+ "    ):\n",
+ "        \"\"\"\n",
+ "        Args:\n",
+ "            atom_type: Types of atoms, (N, ).\n",
+ "            bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n",
+ "            bond_type: Bond types, (E, ).\n",
+ "            batch: Node index to graph index, (N, ).\n",
+ "        \"\"\"\n",
+ "        N = atom_type.size(0)\n",
+ "        if edge_index is None or edge_type is None or edge_length is None:\n",
+ "            edge_index, edge_type = extend_graph_order_radius(\n",
+ "                num_nodes=N,\n",
+ "                pos=pos,\n",
+ "                edge_index=bond_index,\n",
+ "                edge_type=bond_type,\n",
+ "                batch=batch,\n",
+ "                order=self.edge_order,\n",
+ "                cutoff=self.cutoff,\n",
+ "                extend_order=extend_order,\n",
+ "                extend_radius=extend_radius,\n",
+ "                is_sidechain=is_sidechain,\n",
+ "            )\n",
+ "            edge_length = get_distance(pos, edge_index).unsqueeze(-1)  # (E, 1)\n",
+ "        local_edge_mask = is_local_edge(edge_type)  # (E, )\n",
+ "\n",
+ "        # with the parameterization of NCSNv2,\n",
+ "        # the DDPM loss implicitly handles the noise variance scale conditioning\n",
+ "        sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device)  # (E, 1)\n",
+ "\n",
+ "        # Encoding global\n",
+ "        edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type)  # Embed edges\n",
+ "\n",
+ "        # Global\n",
+ "        node_attr_global = self.encoder_global(\n",
+ "            z=atom_type,\n",
+ "            edge_index=edge_index,\n",
+ "            edge_length=edge_length,\n",
+ "            edge_attr=edge_attr_global,\n",
+ "        )\n",
+ "        # Assemble pairwise features\n",
+ "        h_pair_global = assemble_atom_pair_feature(\n",
+ "            node_attr=node_attr_global,\n",
+ "            edge_index=edge_index,\n",
+ "            edge_attr=edge_attr_global,\n",
+ "        )  # (E_global, 2H)\n",
+ "        # Invariant features of edges (radius graph, global)\n",
+ "        edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge)  # (E_global, 1)\n",
+ "\n",
+ "        # Encoding local\n",
+ "        edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type)  # Embed edges\n",
+ "        # edge_attr += temb_edge\n",
+ "\n",
+ "        # Local\n",
+ "        node_attr_local = self.encoder_local(\n",
+ "            z=atom_type,\n",
+ "            edge_index=edge_index[:, local_edge_mask],\n",
+ "            edge_attr=edge_attr_local[local_edge_mask],\n",
+ "        )\n",
+ "        # Assemble pairwise features\n",
+ "        h_pair_local = assemble_atom_pair_feature(\n",
+ "            node_attr=node_attr_local,\n",
+ "            edge_index=edge_index[:, local_edge_mask],\n",
+ "            edge_attr=edge_attr_local[local_edge_mask],\n",
+ "        )  # (E_local, 2H)\n",
+ "\n",
+ "        # Invariant features of edges (bond graph, local)\n",
+ "        if isinstance(sigma_edge, torch.Tensor):\n",
+ "            edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n",
+ "                1.0 / sigma_edge[local_edge_mask]\n",
+ "            )  # (E_local, 1)\n",
+ "        else:\n",
+ "            edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge)  # (E_local, 1)\n",
+ "\n",
+ "        if return_edges:\n",
+ "            return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n",
+ "        else:\n",
+ "            return edge_inv_global, edge_inv_local\n",
+ "\n",
+ "    def forward(\n",
+ "        self,\n",
+ "        sample,\n",
+ "        timestep: Union[torch.Tensor, float, int],\n",
+ "        return_dict: bool = True,\n",
+ "        sigma=1.0,\n",
+ "        global_start_sigma=0.5,\n",
+ "        w_global=1.0,\n",
+ "        extend_order=False,\n",
+ "        extend_radius=True,\n",
+ "        clip_local=None,\n",
+ "        clip_global=1000.0,\n",
+ "    ) -> Union[MoleculeGNNOutput, Tuple]:\n",
+ "        r\"\"\"\n",
+ "        Args:\n",
+ "            sample: a packed torch geometric object holding the batched molecular graph\n",
+ "            timestep (`torch.Tensor` or `float` or `int`): TODO verify type and shape (batch) timesteps\n",
+ "            return_dict (`bool`, *optional*, defaults to `True`):\n",
+ "                Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n",
+ "        Returns:\n",
+ "            [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n",
+ "            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n",
+ "        \"\"\"\n",
+ "\n",
+ "        # unpack sample\n",
+ "        atom_type = sample.atom_type\n",
+ "        bond_index = sample.edge_index\n",
+ "        bond_type = sample.edge_type\n",
+ "        num_graphs = sample.num_graphs\n",
+ "        pos = sample.pos\n",
+ "\n",
+ "        timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n",
+ "\n",
+ "        edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n",
+ "            atom_type=atom_type,\n",
+ "            pos=sample.pos,\n",
+ "            bond_index=bond_index,\n",
+ "            bond_type=bond_type,\n",
+ "            batch=sample.batch,\n",
+ "            time_step=timesteps,\n",
+ "            return_edges=True,\n",
+ "            extend_order=extend_order,\n",
+ "            extend_radius=extend_radius,\n",
+ "        )  # (E_global, 1), (E_local, 1)\n",
+ "\n",
+ "        # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n",
+ "        node_eq_local = graph_field_network(\n",
+ "            edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n",
+ "        )\n",
+ "        if clip_local is not None:\n",
+ "            node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n",
+ "\n",
+ "        # Global\n",
+ "        if sigma < global_start_sigma:\n",
+ "            edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n",
+ "            node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n",
+ "            node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n",
+ "        else:\n",
+ "            node_eq_global = 0\n",
+ "\n",
+ "        # Sum\n",
+ "        eps_pos = node_eq_local + node_eq_global * w_global\n",
+ "\n",
+ "        if not return_dict:\n",
+ "            return (-eps_pos,)\n",
+ "\n",
+ "        return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
+ ],
+ "metadata": {
+ "id": "MCeZA1qQXzoK"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CCIrPYSJj9wd"
+ },
+ "source": [
+ "### Load pretrained model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YdrAr6Ch--Ab"
+ },
+ "source": [
+ "#### Load a model\n",
+ "The model uses a custom equivariant convolutional layer, named graph field network (GFN).\n",
+ "\n",
+ "The warning about `betas` and `alphas` can be ignored; those were moved to the scheduler.",
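+ "\n",
+ "\n",
+ "For orientation, the loading step boils down to something like the sketch below. The checkpoint id and the scheduler settings here are illustrative assumptions, not taken from this notebook:\n",
+ "\n",
+ "```python\n",
+ "from diffusers import DDPMScheduler\n",
+ "\n",
+ "model = MoleculeGNN.from_pretrained(\"fusing/gfn-molecule-gen-drugs\").to(\"cuda\")  # hypothetical checkpoint id\n",
+ "scheduler = DDPMScheduler(num_train_timesteps=1000, clip_sample=False)\n",
+ "scheduler.set_timesteps(1000)  # populates `scheduler.timesteps`, used by the sampling loop below\n",
+ "```"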
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DyCo0nsqjbml", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 172, + "referenced_widgets": [ + "d90f304e9560472eacfbdd11e46765eb", + "1c6246f15b654f4daa11c9bcf997b78c", + "c2321b3bff6f490ca12040a20308f555", + "b7feb522161f4cf4b7cc7c1a078ff12d", + "e2d368556e494ae7ae4e2e992af2cd4f", + "bbef741e76ec41b7ab7187b487a383df", + "561f742d418d4721b0670cc8dd62e22c", + "872915dd1bb84f538c44e26badabafdd", + "d022575f1fa2446d891650897f187b4d", + "fdc393f3468c432aa0ada05e238a5436", + "2c9362906e4b40189f16d14aa9a348da", + "6010fc8daa7a44d5aec4b830ec2ebaa1", + "7e0bb1b8d65249d3974200686b193be2", + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "6526646be5ed415c84d1245b040e629b", + "24d31fc3576e43dd9f8301d2ef3a37ab", + "2918bfaadc8d4b1a9832522c40dfefb8", + "a4bfdca35cc54dae8812720f1b276a08", + "e4901541199b45c6a18824627692fc39", + "f915cf874246446595206221e900b2fe", + "a9e388f22a9742aaaf538e22575c9433", + "42f6c3db29d7484ba6b4f73590abd2f4" + ] + }, + "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" }, - "_ngl_repr_dict": { - "0": { - "0": { - "params": { - "aspectRatio": 1.5, - "assembly": "default", - "bondScale": 0.3, - "bondSpacing": 0.75, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", + "\n", + "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", + "\n" + ] } - }, - "1": { - "0": { - "params": { - "aspectRatio": 1.5, - "assembly": "default", - "bondScale": 0.3, - "bondSpacing": 0.75, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 + ], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", + "dataset = torch.load('/content/molecules.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QZcmy1EvKQRk" + }, + "source": [ + "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JVjz6iH_H6Eh",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 20
+ }
+ ],
+ "source": [
+ "dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run the diffusion process"
+ ],
+ "metadata": {
+ "id": "vHNiZAUxNgoy"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jZ1KZrxKqENg"
+ },
+ "source": [
+ "#### Helper Functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "s240tYueqKKf"
+ },
+ "outputs": [],
+ "source": [
+ "from torch_geometric.data import Data, Batch\n",
+ "from torch_scatter import scatter_add, scatter_mean\n",
+ "from tqdm import tqdm\n",
+ "import copy\n",
+ "import os\n",
+ "\n",
+ "def repeat_data(data: Data, num_repeat) -> Batch:\n",
+ "    datas = [copy.deepcopy(data) for i in range(num_repeat)]\n",
+ "    return Batch.from_data_list(datas)\n",
+ "\n",
+ "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n",
+ "    datas = batch.to_data_list()\n",
+ "    new_data = []\n",
+ "    for i in range(num_repeat):\n",
+ "        new_data += copy.deepcopy(datas)\n",
+ "    return Batch.from_data_list(new_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AMnQTk0eqT7Z"
+ },
+ "source": [
+ "#### Constants"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WYGkzqgzrHmF"
+ },
+ "outputs": [],
+ "source": [
+ "num_samples = 1  # solutions per molecule\n",
+ "num_molecules = 3\n",
+ "\n",
+ "DEVICE = 'cuda'\n",
+ "sampling_type = 'ddpm_noisy'  # the paper also uses \"generalize\" and \"ld\"\n",
+ "# constants for inference\n",
+ "w_global = 0.5  # use 0.3 for qm9\n",
+ "global_start_sigma = 0.5\n",
+ "eta = 1.0\n",
+ "clip_local = None\n",
+ "clip_pos = None\n",
+ "\n",
+ "# constants for data handling\n",
+ "save_traj = False\n",
+ "save_data = False\n",
+ "output_dir = '/content/'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-xD5bJ3SqM7t"
+ },
+ "source": [
+ "#### Generate samples!\n",
+ "Note that the 3d representation of a molecule is referred to as its **conformation**."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "x9xuLUNg26z1",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ "  after removing the cwd from sys.path.\n",
+ "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pickle\n",
+ "\n",
+ "results = []\n",
+ "\n",
+ "# define sigmas\n",
+ "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n",
+ "sigmas = sigmas.to(DEVICE)\n",
+ "\n",
+ "for count, data in enumerate(tqdm(dataset)):\n",
+ "    num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n",
+ "\n",
+ "    data_input = data.clone()\n",
+ "    data_input['pos_ref'] = None\n",
+ "    batch = repeat_data(data_input, num_samples).to(DEVICE)\n",
+ "\n",
+ "    # initial configuration\n",
+ "    pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n",
+ "\n",
+ "    # for logging animation of denoising\n",
+ "    pos_traj = []\n",
+ "    with torch.no_grad():\n",
+ "\n",
+ "        # scale initial sample\n",
+ "        pos = pos_init * sigmas[-1]\n",
+ "        for t in scheduler.timesteps:\n",
+ "            batch.pos = pos\n",
+ "\n",
+ "            # generate geometry with model, then filter it\n",
+ "            epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n",
+ "\n",
+ "            # Update\n",
+ "            reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n",
+ "\n",
+ "            pos = reconstructed_pos\n",
+ "\n",
+ "            if torch.isnan(pos).any():\n",
+ "                print(\"NaN detected. Please restart.\")\n",
+ "                raise FloatingPointError()\n",
+ "\n",
+ "            # recenter graph of positions for next iteration\n",
+ "            pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n",
+ "\n",
+ "            # optional clipping\n",
+ "            if clip_pos is not None:\n",
+ "                pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n",
+ "            pos_traj.append(pos.clone().cpu())\n",
+ "\n",
+ "    pos_gen = pos.cpu()\n",
+ "    if save_traj:\n",
+ "        data.pos_gen = torch.stack(pos_traj)  # full denoising trajectory; entries are already on CPU\n",
+ "    else:\n",
+ "        data.pos_gen = pos_gen\n",
+ "    results.append(data)\n",
+ "\n",
+ "\n",
+ "if save_data:\n",
+ "    save_path = os.path.join(output_dir, 'samples_all.pkl')\n",
+ "\n",
+ "    with open(save_path, 'wb') as f:\n",
+ "        pickle.dump(results, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Render the results!"
+ ],
+ "metadata": {
+ "id": "fSApwSaZNndW"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "d47Zxo2OKdgZ"
+ },
+ "source": [
+ "This function allows us to render 3D in Colab."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "e9Cd0kCAv9b8"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import output\n",
+ "output.enable_custom_widget_manager()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Helper functions"
+ ],
+ "metadata": {
+ "id": "RjaVuR15NqzF"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "28rBYa9NKhlz"
+ },
+ "source": [
+ "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer.",
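+ "\n",
+ "\n",
+ "As a quick illustration (a sketch assuming `results` from the sampling loop above is non-empty), the helper defined below is used like this:\n",
+ "\n",
+ "```python\n",
+ "data = results[0]\n",
+ "n_atoms = data.rdmol.GetNumAtoms()\n",
+ "pos_0 = data.pos_gen.reshape(-1, n_atoms, 3)[0]  # first generated conformation\n",
+ "mol_3d = set_rdmol_positions(data.rdmol, pos_0)  # RDKit Mol carrying the generated coordinates\n",
+ "```"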
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LKdKdwxcyTQ6"
+ },
+ "outputs": [],
+ "source": [
+ "from copy import deepcopy\n",
+ "def set_rdmol_positions(rdkit_mol, pos):\n",
+ "    \"\"\"\n",
+ "    Args:\n",
+ "        rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
+ "        pos: (N_atoms, 3)\n",
+ "    \"\"\"\n",
+ "    mol = deepcopy(rdkit_mol)\n",
+ "    set_rdmol_positions_(mol, pos)\n",
+ "    return mol\n",
+ "\n",
+ "def set_rdmol_positions_(mol, pos):\n",
+ "    \"\"\"\n",
+ "    Args:\n",
+ "        mol: An `rdkit.Chem.rdchem.Mol` object; modified in place.\n",
+ "        pos: (N_atoms, 3)\n",
+ "    \"\"\"\n",
+ "    for i in range(pos.shape[0]):\n",
+ "        mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n",
+ "    return mol\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NuE10hcpKmzK"
+ },
+ "source": [
+ "Process the generated data to make it easy to view."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KieVE1vc0_Vs",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "collect 5 generated molecules in `mols`\n"
+ ]
+ }
+ ],
+ "source": [
+ "# the model can generate multiple conformations per 2d geometry\n",
+ "num_gen = results[0]['pos_gen'].shape[0]\n",
+ "\n",
+ "# init storage objects\n",
+ "mols_gen = []\n",
+ "mols_orig = []\n",
+ "for to_process in results:\n",
+ "\n",
+ "    # store the reference 3d position\n",
+ "    to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n",
+ "\n",
+ "    # store the generated 3d position\n",
+ "    to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n",
+ "\n",
+ "    # copy data to new object\n",
+ "    new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n",
+ "\n",
+ "    # append results\n",
+ "    mols_gen.append(new_mol)\n",
+ "    mols_orig.append(to_process.rdmol)\n",
+ "\n",
+ "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tin89JwMKp4v"
+ },
+ "source": [
+ "Import tools to visualize the 2d chemical diagram of the molecule."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yqV6gllSZn38"
+ },
+ "outputs": [],
+ "source": [
+ "from rdkit.Chem import AllChem\n",
+ "from rdkit import Chem\n",
+ "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n",
+ "from IPython.display import SVG, display"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TFNKmGddVoOk"
+ },
+ "source": [
+ "Select the molecule to visualize."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KzuwLlrrVaGc"
+ },
+ "outputs": [],
+ "source": [
+ "idx = 0\n",
+ "assert idx < len(results), \"selected molecule that was not generated\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Viewing"
+ ],
+ "metadata": {
+ "id": "hkb8w0_SNtU8"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I3R4QBQeKttN"
+ },
+ "source": [
+ "This 2D rendering is the equivalent of the **input to the model**!"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gkQRWjraaKex", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 321 + }, + "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" }, - "metalness": 0, - "multipleBond": "off", - "opacity": 1, - "openEnded": true, - "quality": "high", - "radialSegments": 20, - "radiusData": {}, - "radiusScale": 2, - "radiusSize": 0.15, - "radiusType": "size", - "roughness": 0.4, - "sele": "", - "side": "double", - "sphereDetail": 2, - "useInteriorColor": true, - "visible": true, - "wireframe": false - }, - "type": "ball+stick" + "metadata": {} } - } + ], + "source": [ + "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", + "molSize=(450,300)\n", + "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", + "drawer.DrawMolecule(mc)\n", + "drawer.FinishDrawing()\n", + "svg = drawer.GetDrawingText()\n", + "display(SVG(svg.replace('svg:','')))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z4FDMYMxKw2I" }, - "_ngl_serialize": false, - "_ngl_version": "", - "_ngl_view_id": [ - "FB989FD1-5B9C-446B-8914-6B58AF85446D" + "source": [ + "Generate the 3d molecule!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aT1Bkb8YxJfV", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "695ab5bbf30a4ab19df1f9f33469f314", + "eac6a8dcdc9d4335a2e51031793ead29" + ] + }, + "outputId": "b98870ae-049d-4386-b676-166e9526bda2" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "695ab5bbf30a4ab19df1f9f33469f314" + } + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + } + } + } + } + } ], - "_player_dict": {}, - "_scene_position": {}, - "_scene_rotation": {}, - "_synced_model_ids": [], - "_synced_repr_model_ids": [], - "_view_count": null, - "_view_height": "", - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "NGLView", - "_view_width": "", - "background": "white", - "frame": 0, - "gui_style": null, - "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", - "max_frame": 0, - "n_components": 2, - "picked": {} - } - }, - "c2321b3bff6f490ca12040a20308f555": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", - "max": 3271865, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", - "value": 3271865 - } - }, - "c30e6c2f3e2a44dbbb3d63bd519acaa4": { - "model_module": "@jupyter-widgets/base", - 
"model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "c6596896148b4a8a9c57963b67c7782f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d022575f1fa2446d891650897f187b4d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "d90f304e9560472eacfbdd11e46765eb": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", - "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", - "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" + "source": [ + "from nglview import show_rdkit as show" + 
] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pxtq8I-I18C-", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337, + "referenced_widgets": [ + "be446195da2b4ff2aec21ec5ff963a54", + "c6596896148b4a8a9c57963b67c7782f", + "2489b5e5648541fbbdceadb05632a050", + "01e0ba4e5da04914b4652b8d58565d7b", + "c30e6c2f3e2a44dbbb3d63bd519acaa4", + "f31c6e40e9b2466a9064a2669933ecd5", + "19308ccac642498ab8b58462e3f1b0bb", + "4a081cdc2ec3421ca79dd933b7e2b0c4", + "e5c0d75eb5e1447abd560c8f2c6017e1", + "5146907ef6764654ad7d598baebc8b58", + "144ec959b7604a2cabb5ca46ae5e5379", + "abce2a80e6304df3899109c6d6cac199", + "65195cb7a4134f4887e9dd19f3676462" + ] + }, + "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "NGLWidget()" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "be446195da2b4ff2aec21ec5ff963a54" + } + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + } + } + } + } + } ], - "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" - } - }, - "e2d368556e494ae7ae4e2e992af2cd4f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e4901541199b45c6a18824627692fc39": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - 
"min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e5c0d75eb5e1447abd560c8f2c6017e1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "PlayModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "PlayModel", - "_playing": false, - "_repeat": false, - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "PlayView", - "description": "", - "description_tooltip": null, - "disabled": false, - "interval": 100, - "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", - "max": 0, - "min": 0, - "show_repeat": true, - "step": 1, - "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", - "value": 0 - } - }, - "eac6a8dcdc9d4335a2e51031793ead29": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f31c6e40e9b2466a9064a2669933ecd5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "f915cf874246446595206221e900b2fe": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "fdc393f3468c432aa0ada05e238a5436": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } + "source": [ + "# new molecule\n", + "show(mols_gen[idx])" + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "KJr4h2mwXeTo" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "d90f304e9560472eacfbdd11e46765eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", + "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", + "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" + ], + "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" + } + }, + "1c6246f15b654f4daa11c9bcf997b78c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", + "placeholder": "​", + "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", + "value": "Downloading: 100%" + } + }, + "c2321b3bff6f490ca12040a20308f555": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", + "max": 3271865, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", + "value": 3271865 + } + }, + "b7feb522161f4cf4b7cc7c1a078ff12d": { + 
"model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", + "placeholder": "​", + "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", + "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" + } + }, + "e2d368556e494ae7ae4e2e992af2cd4f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bbef741e76ec41b7ab7187b487a383df": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "561f742d418d4721b0670cc8dd62e22c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": 
"StyleView", + "description_width": "" + } + }, + "872915dd1bb84f538c44e26badabafdd": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d022575f1fa2446d891650897f187b4d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "fdc393f3468c432aa0ada05e238a5436": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c9362906e4b40189f16d14aa9a348da": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6010fc8daa7a44d5aec4b830ec2ebaa1": { + "model_module": 
"@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", + "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "IPY_MODEL_6526646be5ed415c84d1245b040e629b" + ], + "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" + } + }, + "7e0bb1b8d65249d3974200686b193be2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", + "placeholder": "​", + "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", + "value": "Downloading: 100%" + } + }, + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", + "max": 401, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", + "value": 401 + } + }, + "6526646be5ed415c84d1245b040e629b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", + "placeholder": "​", + "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", + "value": " 401/401 [00:00<00:00, 13.5kB/s]" + } + }, + "24d31fc3576e43dd9f8301d2ef3a37ab": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + 
"justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2918bfaadc8d4b1a9832522c40dfefb8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a4bfdca35cc54dae8812720f1b276a08": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e4901541199b45c6a18824627692fc39": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f915cf874246446595206221e900b2fe": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a9e388f22a9742aaaf538e22575c9433": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "42f6c3db29d7484ba6b4f73590abd2f4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "695ab5bbf30a4ab19df1f9f33469f314": { + "model_module": "nglview-js-widgets", + "model_name": "ColormakerRegistryModel", + "model_module_version": "3.0.1", + "state": { + "_dom_classes": [], + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "ColormakerRegistryModel", + "_msg_ar": [], + "_msg_q": [], + "_ready": false, + "_view_count": null, + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "ColormakerRegistryView", + "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" + } + }, + "eac6a8dcdc9d4335a2e51031793ead29": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + 
"min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be446195da2b4ff2aec21ec5ff963a54": { + "model_module": "nglview-js-widgets", + "model_name": "NGLModel", + "model_module_version": "3.0.1", + "state": { + "_camera_orientation": [ + -15.519693580202304, + -14.065056548036177, + -23.53197484807691, + 0, + -23.357853515109753, + 20.94055073042662, + 2.888695042134944, + 0, + 14.352363398292777, + 18.870825741878015, + -20.744689572909344, + 0, + 0.2724999189376831, + 0.6940000057220459, + -0.3734999895095825, + 1 + ], + "_camera_str": "orthographic", + "_dom_classes": [], + "_gui_theme": null, + "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", + "_igui": null, + "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "NGLModel", + "_ngl_color_dict": {}, + "_ngl_coordinate_resource": {}, + "_ngl_full_stage_parameters": { + "impostor": true, + "quality": "medium", + "workerDefault": true, + "sampleLevel": 0, + "backgroundColor": "white", + "rotateSpeed": 2, + "zoomSpeed": 1.2, + "panSpeed": 1, + "clipNear": 0, + "clipFar": 100, + "clipDist": 10, + "fogNear": 50, + "fogFar": 100, + "cameraFov": 40, + "cameraEyeSep": 0.3, + "cameraType": "perspective", + "lightColor": 14540253, + "lightIntensity": 1, + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "hoverTimeout": 0, + "tooltip": true, + "mousePreset": "default" + }, + "_ngl_msg_archive": [ + { + "target": "Stage", + "type": "call_method", + "methodName": "loadFile", + "reconstruc_color_scheme": false, + "args": [ + { + "type": "blob", + "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 
UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", + "binary": false + } + ], + "kwargs": { + "defaultRepresentation": true, + "ext": "pdb" + } + } + ], + "_ngl_original_stage_parameters": { + "impostor": true, + "quality": "medium", + "workerDefault": true, + "sampleLevel": 0, + "backgroundColor": "white", + "rotateSpeed": 2, + "zoomSpeed": 1.2, + "panSpeed": 1, + "clipNear": 0, + "clipFar": 100, + "clipDist": 10, + "fogNear": 50, + "fogFar": 100, + "cameraFov": 40, + "cameraEyeSep": 0.3, + "cameraType": "perspective", + "lightColor": 14540253, + "lightIntensity": 1, + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "hoverTimeout": 0, + "tooltip": true, + "mousePreset": "default" + }, + "_ngl_repr_dict": { + "0": { + "0": { + "type": "ball+stick", + "params": { + "lazy": false, + "visible": true, + "quality": "high", + "sphereDetail": 2, + "radialSegments": 20, + "openEnded": true, + "disableImpostor": false, + "aspectRatio": 1.5, + "lineOnly": false, + "cylinderOnly": false, + "multipleBond": "off", + "bondScale": 0.3, + "bondSpacing": 0.75, + "linewidth": 2, + "radiusType": "size", + "radiusData": {}, + "radiusSize": 0.15, + "radiusScale": 2, + "assembly": "default", + "defaultAssembly": "", + "clipNear": 0, + "clipRadius": 0, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 + }, + "flatShaded": false, + "opacity": 1, + "depthWrite": true, + "side": "double", + "wireframe": false, + "colorScheme": "element", + "colorScale": "", + "colorReverse": false, + "colorValue": 9474192, + "colorMode": "hcl", + "roughness": 0.4, + "metalness": 0, + "diffuse": 16777215, + "diffuseInterior": false, + "useInteriorColor": true, + "interiorColor": 2236962, + "interiorDarkening": 0, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "disablePicking": false, + "sele": "" + } + } + }, + "1": { + "0": { + "type": "ball+stick", + "params": { + "lazy": false, + "visible": true, + "quality": "high", + "sphereDetail": 2, + "radialSegments": 20, + "openEnded": true, + "disableImpostor": false, + "aspectRatio": 1.5, + "lineOnly": false, + "cylinderOnly": false, + "multipleBond": 
"off", + "bondScale": 0.3, + "bondSpacing": 0.75, + "linewidth": 2, + "radiusType": "size", + "radiusData": {}, + "radiusSize": 0.15, + "radiusScale": 2, + "assembly": "default", + "defaultAssembly": "", + "clipNear": 0, + "clipRadius": 0, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 + }, + "flatShaded": false, + "opacity": 1, + "depthWrite": true, + "side": "double", + "wireframe": false, + "colorScheme": "element", + "colorScale": "", + "colorReverse": false, + "colorValue": 9474192, + "colorMode": "hcl", + "roughness": 0.4, + "metalness": 0, + "diffuse": 16777215, + "diffuseInterior": false, + "useInteriorColor": true, + "interiorColor": 2236962, + "interiorDarkening": 0, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "disablePicking": false, + "sele": "" + } + } + } + }, + "_ngl_serialize": false, + "_ngl_version": "", + "_ngl_view_id": [ + "FB989FD1-5B9C-446B-8914-6B58AF85446D" + ], + "_player_dict": {}, + "_scene_position": {}, + "_scene_rotation": {}, + "_synced_model_ids": [], + "_synced_repr_model_ids": [], + "_view_count": null, + "_view_height": "", + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "NGLView", + "_view_width": "", + "background": "white", + "frame": 0, + "gui_style": null, + "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", + "max_frame": 0, + "n_components": 2, + "picked": {} + } + }, + "c6596896148b4a8a9c57963b67c7782f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2489b5e5648541fbbdceadb05632a050": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ButtonModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "", + "description": "", + "disabled": false, + "icon": "compress", + "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", + "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", + "tooltip": "" + } + }, + "01e0ba4e5da04914b4652b8d58565d7b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", + "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" + ], + "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" + } + }, + "c30e6c2f3e2a44dbbb3d63bd519acaa4": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f31c6e40e9b2466a9064a2669933ecd5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "19308ccac642498ab8b58462e3f1b0bb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4a081cdc2ec3421ca79dd933b7e2b0c4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "SliderStyleModel", + "model_module_version": 
"1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "SliderStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "", + "handle_color": null + } + }, + "e5c0d75eb5e1447abd560c8f2c6017e1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "PlayModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "PlayModel", + "_playing": false, + "_repeat": false, + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "PlayView", + "description": "", + "description_tooltip": null, + "disabled": false, + "interval": 100, + "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", + "max": 0, + "min": 0, + "show_repeat": true, + "step": 1, + "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", + "value": 0 + } + }, + "5146907ef6764654ad7d598baebc8b58": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntSliderModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "IntSliderModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "IntSliderView", + "continuous_update": true, + "description": "", + "description_tooltip": null, + "disabled": false, + "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", + "max": 0, + "min": 0, + "orientation": "horizontal", + "readout": true, + "readout_format": "d", + "step": 1, + "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", + "value": 0 + } + }, + "144ec959b7604a2cabb5ca46ae5e5379": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "abce2a80e6304df3899109c6d6cac199": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + 
"align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "34px" + } + }, + "65195cb7a4134f4887e9dd19f3676462": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ButtonStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + } + } } - } - } - }, - "nbformat": 4, - "nbformat_minor": 0 + }, + "nbformat": 4, + "nbformat_minor": 0 } \ No newline at end of file diff --git a/examples/research_projects/gligen/demo.ipynb b/examples/research_projects/gligen/demo.ipynb index 4930253ff66e..571f1a0323a2 100644 --- a/examples/research_projects/gligen/demo.ipynb +++ b/examples/research_projects/gligen/demo.ipynb @@ -26,7 +26,8 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "from diffusers import StableDiffusionGLIGENPipeline" + "import torch\n", + "from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline" ] }, { @@ -35,17 +36,16 @@ "metadata": {}, "outputs": [], "source": [ - "from transformers import CLIPTextModel, CLIPTokenizer\n", - "\n", + "import os\n", "import diffusers\n", "from diffusers import (\n", " AutoencoderKL,\n", " DDPMScheduler,\n", - " EulerDiscreteScheduler,\n", " UNet2DConditionModel,\n", + " UniPCMultistepScheduler,\n", + " EulerDiscreteScheduler,\n", ")\n", - "\n", - "\n", + "from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n", "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n", "\n", "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n", @@ -122,7 +122,6 @@ "\n", "import numpy as np\n", "\n", - "\n", "boxes = np.array([x[1] for x in gen_boxes])\n", "boxes = boxes / 512\n", "boxes[:, 2] = boxes[:, 0] + boxes[:, 2]\n",